diff --git a/django-stubs/db/models/fields/__init__.pyi b/django-stubs/db/models/fields/__init__.pyi index 6add1186b..eff311126 100644 --- a/django-stubs/db/models/fields/__init__.pyi +++ b/django-stubs/db/models/fields/__init__.pyi @@ -406,4 +406,5 @@ class DurationField(Field[_ST, _GT]): _pyi_private_get_type: timedelta class BigAutoField(AutoField[_ST, _GT]): ... +class SmallAutoField(AutoField[_ST, _GT]): ... class CommaSeparatedIntegerField(CharField[_ST, _GT]): ... diff --git a/mypy_django_plugin/django/context.py b/mypy_django_plugin/django/context.py index 36ac0f869..67d6345d4 100644 --- a/mypy_django_plugin/django/context.py +++ b/mypy_django_plugin/django/context.py @@ -16,12 +16,13 @@ from django.db.models.sql.query import Query from django.utils.functional import cached_property from mypy.checker import TypeChecker +from mypy.nodes import TypeInfo from mypy.plugin import MethodContext from mypy.types import AnyType, Instance from mypy.types import Type as MypyType from mypy.types import TypeOfAny, UnionType -from mypy_django_plugin.lib import fullnames, helpers +from mypy_django_plugin.lib import chk_helpers, fullnames, helpers try: from django.contrib.postgres.fields import ArrayField @@ -52,14 +53,6 @@ def initialize_django(settings_module: str) -> Tuple['Apps', 'LazySettings']: # add current directory to sys.path sys.path.append(os.getcwd()) - def noop_class_getitem(cls, key): - return cls - - from django.db import models - - models.QuerySet.__class_getitem__ = classmethod(noop_class_getitem) # type: ignore - models.Manager.__class_getitem__ = classmethod(noop_class_getitem) # type: ignore - from django.conf import settings from django.apps import apps @@ -119,10 +112,10 @@ def get_model_fields(self, model_cls: Type[Model]) -> Iterator[Field]: if isinstance(field, Field): yield field - def get_model_relations(self, model_cls: Type[Model]) -> Iterator[ForeignObjectRel]: - for field in model_cls._meta.get_fields(): - if isinstance(field, ForeignObjectRel): - yield field + def get_model_relations(self, model_cls: Type[Model]) -> Iterator[Tuple[Optional[str], ForeignObjectRel]]: + for relation in model_cls._meta.get_fields(): + if isinstance(relation, ForeignObjectRel): + yield relation.get_accessor_name(), relation def get_field_lookup_exact_type(self, api: TypeChecker, field: Union[Field, ForeignObjectRel]) -> MypyType: if isinstance(field, (RelatedField, ForeignObjectRel)): @@ -222,11 +215,15 @@ def all_registered_model_classes(self) -> Set[Type[models.Model]]: def all_registered_model_class_fullnames(self) -> Set[str]: return {helpers.get_class_fullname(cls) for cls in self.all_registered_model_classes} + def is_model_subclass(self, info: TypeInfo) -> bool: + return (info.fullname in self.all_registered_model_class_fullnames + or info.has_base(fullnames.MODEL_CLASS_FULLNAME)) + def get_attname(self, field: Field) -> str: attname = field.attname return attname - def get_field_nullability(self, field: Union[Field, ForeignObjectRel], method: Optional[str]) -> bool: + def get_field_nullability(self, field: Union[Field, ForeignObjectRel], method: Optional[str] = None) -> bool: nullable = field.null if not nullable and isinstance(field, CharField) and field.blank: return True @@ -356,11 +353,11 @@ def resolve_lookup_expected_type(self, ctx: MethodContext, model_cls: Type[Model return AnyType(TypeOfAny.explicit) if lookup_cls is None or isinstance(lookup_cls, Exact): - return self.get_field_lookup_exact_type(helpers.get_typechecker_api(ctx), field) + return self.get_field_lookup_exact_type(chk_helpers.get_typechecker_api(ctx), field) assert lookup_cls is not None - lookup_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), lookup_cls) + lookup_info = helpers.lookup_class_typeinfo(chk_helpers.get_typechecker_api(ctx), lookup_cls) if lookup_info is None: return AnyType(TypeOfAny.explicit) @@ -370,7 +367,7 @@ def resolve_lookup_expected_type(self, ctx: MethodContext, model_cls: Type[Model # if it's Field, consider lookup_type a __get__ of current field if (isinstance(lookup_type, Instance) and lookup_type.type.fullname == fullnames.FIELD_FULLNAME): - field_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), field.__class__) + field_info = helpers.lookup_class_typeinfo(chk_helpers.get_typechecker_api(ctx), field.__class__) if field_info is None: return AnyType(TypeOfAny.explicit) lookup_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type', diff --git a/mypy_django_plugin/lib/chk_helpers.py b/mypy_django_plugin/lib/chk_helpers.py new file mode 100644 index 000000000..3e62d2bb0 --- /dev/null +++ b/mypy_django_plugin/lib/chk_helpers.py @@ -0,0 +1,126 @@ +from typing import Dict, List, Optional, Set, Union + +from mypy import checker +from mypy.checker import TypeChecker +from mypy.mro import calculate_mro +from mypy.nodes import ( + GDEF, MDEF, Block, ClassDef, Expression, MypyFile, SymbolTable, SymbolTableNode, TypeInfo, Var, +) +from mypy.plugin import ( + AttributeContext, CheckerPluginInterface, FunctionContext, MethodContext, +) +from mypy.types import AnyType, Instance, TupleType +from mypy.types import Type as MypyType +from mypy.types import TypedDictType, TypeOfAny + +from mypy_django_plugin.lib import helpers + + +def add_new_class_for_current_module(current_module: MypyFile, + name: str, + bases: List[Instance], + fields: Optional[Dict[str, MypyType]] = None + ) -> TypeInfo: + new_class_unique_name = checker.gen_unique_name(name, current_module.names) + + # make new class expression + classdef = ClassDef(new_class_unique_name, Block([])) + classdef.fullname = current_module.fullname + '.' + new_class_unique_name + + # make new TypeInfo + new_typeinfo = TypeInfo(SymbolTable(), classdef, current_module.fullname) + new_typeinfo.bases = bases + calculate_mro(new_typeinfo) + new_typeinfo.calculate_metaclass_type() + + # add fields + if fields: + for field_name, field_type in fields.items(): + var = Var(field_name, type=field_type) + var.info = new_typeinfo + var._fullname = new_typeinfo.fullname + '.' + field_name + new_typeinfo.names[field_name] = SymbolTableNode(MDEF, var, plugin_generated=True) + + classdef.info = new_typeinfo + current_module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True) + return new_typeinfo + + +def make_oneoff_named_tuple(api: TypeChecker, name: str, fields: 'Dict[str, MypyType]') -> TupleType: + current_module = helpers.get_current_module(api) + namedtuple_info = add_new_class_for_current_module(current_module, name, + bases=[api.named_generic_type('typing.NamedTuple', [])], + fields=fields) + return TupleType(list(fields.values()), fallback=Instance(namedtuple_info, [])) + + +def make_tuple(api: 'TypeChecker', fields: List[MypyType]) -> TupleType: + # fallback for tuples is any builtins.tuple instance + fallback = api.named_generic_type('builtins.tuple', + [AnyType(TypeOfAny.special_form)]) + return TupleType(fields, fallback=fallback) + + +def make_oneoff_typeddict(api: CheckerPluginInterface, fields: 'Dict[str, MypyType]', + required_keys: Set[str]) -> TypedDictType: + object_type = api.named_generic_type('mypy_extensions._TypedDict', []) + typed_dict_type = TypedDictType(fields, # type: ignore + required_keys=required_keys, + fallback=object_type) + return typed_dict_type + + +def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionContext]) -> TypeChecker: + if not isinstance(ctx.api, TypeChecker): + raise ValueError('Not a TypeChecker') + return ctx.api + + +def check_types_compatible(ctx: Union[FunctionContext, MethodContext], + *, expected_type: MypyType, actual_type: MypyType, error_message: str) -> None: + api = get_typechecker_api(ctx) + api.check_subtype(actual_type, expected_type, + ctx.context, error_message, + 'got', 'expected') + + +def get_call_argument_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[Expression]: + """ + Return the expression for the specific argument. + This helper should only be used with non-star arguments. + """ + if name not in ctx.callee_arg_names: + return None + idx = ctx.callee_arg_names.index(name) + args = ctx.args[idx] + if len(args) != 1: + # Either an error or no value passed. + return None + return args[0] + + +def get_call_argument_type_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[MypyType]: + """Return the type for the specific argument. + + This helper should only be used with non-star arguments. + """ + if name not in ctx.callee_arg_names: + return None + idx = ctx.callee_arg_names.index(name) + arg_types = ctx.arg_types[idx] + if len(arg_types) != 1: + # Either an error or no value passed. + return None + return arg_types[0] + + +def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType) -> None: + # type=: type of the variable itself + var = Var(name=name, type=sym_type) + # var.info: type of the object variable is bound to + var.info = info + var._fullname = info.fullname + '.' + name + var.is_initialized_in_class = True + var.is_inferred = True + info.names[name] = SymbolTableNode(MDEF, var, + plugin_generated=True) diff --git a/mypy_django_plugin/lib/generics.py b/mypy_django_plugin/lib/generics.py new file mode 100644 index 000000000..434a17c50 --- /dev/null +++ b/mypy_django_plugin/lib/generics.py @@ -0,0 +1,6 @@ +def make_classes_generic(*klasses: type) -> None: + for klass in klasses: + def fake_classgetitem(cls, *args, **kwargs): + return cls + + klass.__class_getitem__ = classmethod(fake_classgetitem) # type: ignore diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index c99f2f82c..fef30d77e 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -1,70 +1,382 @@ -from collections import OrderedDict +from abc import abstractmethod from typing import ( - TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, + TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union, cast, ) -from django.db.models.fields import Field -from django.db.models.fields.related import RelatedField -from django.db.models.fields.reverse_related import ForeignObjectRel -from mypy import checker from mypy.checker import TypeChecker from mypy.mro import calculate_mro from mypy.nodes import ( - GDEF, MDEF, Argument, Block, ClassDef, Expression, FuncDef, MemberExpr, MypyFile, NameExpr, PlaceholderNode, - StrExpr, SymbolNode, SymbolTable, SymbolTableNode, TypeInfo, Var, + GDEF, Argument, Block, CallExpr, ClassDef, Context, Expression, FuncDef, MemberExpr, MypyFile, NameExpr, + PlaceholderNode, StrExpr, SymbolTable, SymbolTableNode, TypeInfo, Var, ) from mypy.plugin import ( - AttributeContext, CheckerPluginInterface, ClassDefContext, DynamicClassDefContext, FunctionContext, MethodContext, + AttributeContext, ClassDefContext, DynamicClassDefContext, FunctionContext, MethodContext, ) -from mypy.plugins.common import add_method +from mypy.plugins.common import add_method_to_class from mypy.semanal import SemanticAnalyzer -from mypy.types import AnyType, CallableType, Instance, NoneTyp, TupleType +from mypy.types import AnyType, CallableType, Instance, NoneTyp, ProperType from mypy.types import Type as MypyType -from mypy.types import TypedDictType, TypeOfAny, UnionType +from mypy.types import TypeOfAny, UnionType -from mypy_django_plugin.lib import fullnames +from mypy_django_plugin.transformers import new_helpers if TYPE_CHECKING: from mypy_django_plugin.django.context import DjangoContext + from mypy_django_plugin.main import NewSemanalDjangoPlugin +AnyPluginAPI = Union[TypeChecker, SemanticAnalyzer] -def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]: - return model_info.metadata.setdefault('django', {}) +class DjangoPluginCallback: + django_context: 'DjangoContext' -class IncompleteDefnException(Exception): - pass + def __init__(self, plugin: 'NewSemanalDjangoPlugin') -> None: + self.plugin = plugin + self.django_context = plugin.django_context + def new_typeinfo(self, name: str, bases: List[Instance]) -> TypeInfo: + class_def = ClassDef(name, Block([])) + class_def.fullname = self.qualified_name(name) -def lookup_fully_qualified_sym(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolTableNode]: - if '.' not in fullname: - return None - module, cls_name = fullname.rsplit('.', 1) + info = TypeInfo(SymbolTable(), class_def, self.get_current_module().fullname) + info.bases = bases + calculate_mro(info) + info.metaclass_type = info.calculate_metaclass_type() - module_file = all_modules.get(module) - if module_file is None: - return None - sym = module_file.names.get(cls_name) - if sym is None: - return None - return sym + class_def.info = info + return info + @abstractmethod + def get_current_module(self) -> MypyFile: + raise NotImplementedError() -def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolNode]: - sym = lookup_fully_qualified_sym(name, all_modules) - if sym is None: + @abstractmethod + def qualified_name(self, name: str) -> str: + raise NotImplementedError() + + +class SemanalPluginCallback(DjangoPluginCallback): + semanal_api: SemanticAnalyzer + + def build_defer_error_message(self, message: str) -> str: + return f'{self.__class__.__name__}: {message}' + + def defer_till_next_iteration(self, deferral_context: Optional[Context] = None, + *, + reason: Optional[str] = None) -> bool: + """ Returns False if cannot be deferred. """ + if self.semanal_api.final_iteration: + return False + self.semanal_api.defer(deferral_context) + # when pytest-mypy-plugins changes to incorporate verbose mypy logging, + # uncomment following line to allow better feedback from users on issues + # print(f'LOG: defer: {self.build_defer_error_message(reason)}') + return True + + def get_current_module(self) -> MypyFile: + return self.semanal_api.cur_mod_node + + def qualified_name(self, name: str) -> str: + return self.semanal_api.qualified_name(name) + + def lookup_typeinfo_or_defer(self, fullname: str, *, + deferral_context: Optional[Context] = None, + reason_for_defer: Optional[str] = None) -> Optional[TypeInfo]: + sym = self.plugin.lookup_fully_qualified(fullname) + if sym is None or sym.node is None or isinstance(sym.node, PlaceholderNode): + deferral_context = deferral_context or self.semanal_api.cur_mod_node + reason = reason_for_defer or f'{fullname!r} is not available for lookup' + if not self.defer_till_next_iteration(deferral_context, reason=reason): + raise new_helpers.TypeInfoNotFound(fullname) + return None + + if not isinstance(sym.node, TypeInfo): + raise ValueError(f'{fullname!r} does not correspond to TypeInfo') + + return sym.node + + def copy_method_to_another_class( + self, + ctx: ClassDefContext, + self_type: Instance, + new_method_name: str, + method_node: FuncDef) -> None: + if method_node.type is None: + if not self.semanal_api.final_iteration: + self.semanal_api.defer() + return + + arguments, return_type = build_unannotated_method_args(method_node) + add_method_to_class( + ctx.api, + ctx.cls, + new_method_name, + args=arguments, + return_type=return_type, + self_type=self_type) + return + + method_type = cast(CallableType, method_node.type) + if not isinstance(method_type, CallableType) and not self.defer_till_next_iteration( + reason='method_node.type is not CallableType'): + raise new_helpers.TypeInfoNotFound(method_node.fullname) + + arguments = [] + bound_return_type = self.semanal_api.anal_type( + method_type.ret_type, + allow_placeholder=True) + + if bound_return_type is None and self.defer_till_next_iteration(): + raise new_helpers.TypeInfoNotFound(method_node.fullname + ' return type') + + assert bound_return_type is not None + + if isinstance(bound_return_type, PlaceholderNode): + raise new_helpers.TypeInfoNotFound('return type ' + method_node.fullname) + + for arg_name, arg_type, original_argument in zip( + method_type.arg_names[1:], + method_type.arg_types[1:], + method_node.arguments[1:]): + bound_arg_type = self.semanal_api.anal_type(arg_type, allow_placeholder=True) + if bound_arg_type is None and self.defer_till_next_iteration(reason='bound_arg_type is None'): + error_msg = 'of {} argument of {}'.format(arg_name, method_node.fullname) + raise new_helpers.TypeInfoNotFound(error_msg) + + assert bound_arg_type is not None + + if isinstance(bound_arg_type, PlaceholderNode) and self.defer_till_next_iteration( + reason='bound_arg_type is None'): + raise new_helpers.TypeInfoNotFound('of ' + arg_name + ' argument of ' + method_node.fullname) + + var = Var( + name=original_argument.variable.name, + type=arg_type) + var.line = original_argument.variable.line + var.column = original_argument.variable.column + argument = Argument( + variable=var, + type_annotation=bound_arg_type, + initializer=original_argument.initializer, + kind=original_argument.kind) + argument.set_line(original_argument) + arguments.append(argument) + + add_method_to_class( + ctx.api, + ctx.cls, + new_method_name, + args=arguments, + return_type=bound_return_type, + self_type=self_type) + + def new_typeinfo(self, name: str, bases: List[Instance], module_fullname: Optional[str] = None) -> TypeInfo: + class_def = ClassDef(name, Block([])) + class_def.fullname = self.semanal_api.qualified_name(name) + + info = TypeInfo(SymbolTable(), class_def, + module_fullname or self.get_current_module().fullname) + info.bases = bases + calculate_mro(info) + info.metaclass_type = info.calculate_metaclass_type() + + class_def.info = info + return info + + +class DynamicClassPluginCallback(SemanalPluginCallback): + class_name: str + call_expr: CallExpr + + def __call__(self, ctx: DynamicClassDefContext) -> None: + self.class_name = ctx.name + self.call_expr = ctx.call + self.semanal_api = cast(SemanticAnalyzer, ctx.api) + self.create_new_dynamic_class() + + def generate_manager_info_and_module(self, base_manager_info: TypeInfo) -> Tuple[TypeInfo, MypyFile]: + new_manager_info = self.semanal_api.basic_new_typeinfo( + self.class_name, + basetype_or_fallback=Instance( + base_manager_info, + [AnyType(TypeOfAny.unannotated)]) + ) + new_manager_info.line = self.call_expr.line + new_manager_info.defn.line = self.call_expr.line + new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type() + + current_module = self.semanal_api.cur_mod_node + current_module.names[self.class_name] = SymbolTableNode( + GDEF, + new_manager_info, + plugin_generated=True) + return new_manager_info, current_module + + @abstractmethod + def create_new_dynamic_class(self) -> None: + raise NotImplementedError + + +class DynamicClassFromMethodCallback(DynamicClassPluginCallback): + callee: MemberExpr + + def __call__(self, ctx: DynamicClassDefContext) -> None: + self.class_name = ctx.name + self.call_expr = ctx.call + + assert ctx.call.callee is not None + assert isinstance(ctx.call.callee, MemberExpr) + + self.callee = ctx.call.callee + + self.semanal_api = cast(SemanticAnalyzer, ctx.api) + self.create_new_dynamic_class() + + +class ClassDefPluginCallback(SemanalPluginCallback): + reason: Expression + class_defn: ClassDef + ctx: ClassDefContext + + def __call__(self, ctx: ClassDefContext) -> None: + self.reason = ctx.reason + self.class_defn = ctx.cls + self.semanal_api = cast(SemanticAnalyzer, ctx.api) + self.ctx = ctx + self.modify_class_defn() + + @abstractmethod + def modify_class_defn(self) -> None: + raise NotImplementedError + + +class TypeCheckerPluginCallback(DjangoPluginCallback): + type_checker: TypeChecker + + def get_current_module(self) -> MypyFile: + current_module = None + for item in reversed(self.type_checker.scope.stack): + if isinstance(item, MypyFile): + current_module = item + break + assert current_module is not None + return current_module + + def qualified_name(self, name: str) -> str: + return self.type_checker.scope.stack[-1].fullname + '.' + name + + def lookup_typeinfo(self, fullname: str) -> Optional[TypeInfo]: + sym = self.plugin.lookup_fully_qualified(fullname) + if sym is None or sym.node is None: + return None + if not isinstance(sym.node, TypeInfo): + raise ValueError(f'{fullname!r} does not correspond to TypeInfo') + return sym.node + + +class GetMethodCallback(TypeCheckerPluginCallback): + ctx: MethodContext + callee_type: Instance + default_return_type: MypyType + + def __call__(self, ctx: MethodContext) -> MypyType: + self.type_checker = cast(TypeChecker, ctx.api) + self.ctx = ctx + self.callee_type = cast(Instance, ctx.type) + self.default_return_type = self.ctx.default_return_type + return self.get_method_return_type() + + @abstractmethod + def get_method_return_type(self) -> MypyType: + raise NotImplementedError + + +class GetFunctionCallback(TypeCheckerPluginCallback): + ctx: FunctionContext + default_return_type: MypyType + + def __call__(self, ctx: FunctionContext) -> MypyType: + self.type_checker = cast(TypeChecker, ctx.api) + self.ctx = ctx + self.default_return_type = ctx.default_return_type + return self.get_function_return_type() + + @abstractmethod + def get_function_return_type(self) -> MypyType: + raise NotImplementedError + + +class GetAttributeCallback(TypeCheckerPluginCallback): + obj_type: ProperType + default_attr_type: MypyType + error_context: Union[MemberExpr, NameExpr] + name: str + + def __call__(self, ctx: AttributeContext) -> MypyType: + self.ctx = ctx + self.type_checker = cast(TypeChecker, ctx.api) + self.obj_type = ctx.type + self.default_attr_type = ctx.default_attr_type + + if not isinstance(ctx.context, (MemberExpr, NameExpr)): + return self.default_attr_type + self.error_context = ctx.context + self.name = ctx.context.name + + return self.get_attribute_type() + + @abstractmethod + def get_attribute_type(self) -> MypyType: + raise NotImplementedError() + + +def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]: + return model_info.metadata.setdefault('django', {}) + + +def split_symbol_name(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[Tuple[str, str]]: + if '.' not in fullname: return None - return sym.node + module_name = None + parts = fullname.split('.') + for i in range(len(parts), 0, -1): + possible_module_name = '.'.join(parts[:i]) + if possible_module_name in all_modules: + module_name = possible_module_name + break + if module_name is None: + return None + + symbol_name = fullname.replace(module_name, '').lstrip('.') + return module_name, symbol_name -def lookup_fully_qualified_typeinfo(api: Union[TypeChecker, SemanticAnalyzer], fullname: str) -> Optional[TypeInfo]: - node = lookup_fully_qualified_generic(fullname, api.modules) - if not isinstance(node, TypeInfo): +def lookup_fully_qualified_typeinfo(api: AnyPluginAPI, fullname: str) -> Optional[TypeInfo]: + split = split_symbol_name(fullname, api.modules) + if split is None: + return None + module_name, cls_name = split + + sym_table = api.modules[module_name].names # type: Dict[str, SymbolTableNode] + if '.' in cls_name: + parent_cls_name, _, cls_name = cls_name.rpartition('.') + # nested class + for parent_cls_name in parent_cls_name.split('.'): + sym = sym_table.get(parent_cls_name) + if (sym is None or sym.node is None + or not isinstance(sym.node, TypeInfo)): + return None + sym_table = sym.node.names + + sym = sym_table.get(cls_name) + if (sym is None + or sym.node is None + or not isinstance(sym.node, TypeInfo)): return None - return node + return sym.node -def lookup_class_typeinfo(api: TypeChecker, klass: type) -> Optional[TypeInfo]: +def lookup_class_typeinfo(api: AnyPluginAPI, klass: type) -> Optional[TypeInfo]: fullname = get_class_fullname(klass) field_info = lookup_fully_qualified_typeinfo(api, fullname) return field_info @@ -79,36 +391,6 @@ def get_class_fullname(klass: type) -> str: return klass.__module__ + '.' + klass.__qualname__ -def get_call_argument_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[Expression]: - """ - Return the expression for the specific argument. - This helper should only be used with non-star arguments. - """ - if name not in ctx.callee_arg_names: - return None - idx = ctx.callee_arg_names.index(name) - args = ctx.args[idx] - if len(args) != 1: - # Either an error or no value passed. - return None - return args[0] - - -def get_call_argument_type_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[MypyType]: - """Return the type for the specific argument. - - This helper should only be used with non-star arguments. - """ - if name not in ctx.callee_arg_names: - return None - idx = ctx.callee_arg_names.index(name) - arg_types = ctx.arg_types[idx] - if len(arg_types) != 1: - # Either an error or no value passed. - return None - return arg_types[0] - - def make_optional(typ: MypyType) -> MypyType: return UnionType.make_union([typ, NoneTyp()]) @@ -153,59 +435,10 @@ def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is return AnyType(TypeOfAny.explicit) -def get_field_lookup_exact_type(api: TypeChecker, field: Field) -> MypyType: - if isinstance(field, (RelatedField, ForeignObjectRel)): - lookup_type_class = field.related_model - rel_model_info = lookup_class_typeinfo(api, lookup_type_class) - if rel_model_info is None: - return AnyType(TypeOfAny.from_error) - return make_optional(Instance(rel_model_info, [])) - - field_info = lookup_class_typeinfo(api, field.__class__) - if field_info is None: - return AnyType(TypeOfAny.explicit) - return get_private_descriptor_type(field_info, '_pyi_lookup_exact_type', - is_nullable=field.null) - - -def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]: - metaclass_sym = info.names.get('Meta') - if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo): - return metaclass_sym.node - return None - - -def add_new_class_for_module(module: MypyFile, - name: str, - bases: List[Instance], - fields: Optional[Dict[str, MypyType]] = None - ) -> TypeInfo: - new_class_unique_name = checker.gen_unique_name(name, module.names) - - # make new class expression - classdef = ClassDef(new_class_unique_name, Block([])) - classdef.fullname = module.fullname + '.' + new_class_unique_name - - # make new TypeInfo - new_typeinfo = TypeInfo(SymbolTable(), classdef, module.fullname) - new_typeinfo.bases = bases - calculate_mro(new_typeinfo) - new_typeinfo.calculate_metaclass_type() - - # add fields - if fields: - for field_name, field_type in fields.items(): - var = Var(field_name, type=field_type) - var.info = new_typeinfo - var._fullname = new_typeinfo.fullname + '.' + field_name - new_typeinfo.names[field_name] = SymbolTableNode(MDEF, var, plugin_generated=True) +def get_current_module(api: AnyPluginAPI) -> MypyFile: + if isinstance(api, SemanticAnalyzer): + return api.cur_mod_node - classdef.info = new_typeinfo - module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True) - return new_typeinfo - - -def get_current_module(api: TypeChecker) -> MypyFile: current_module = None for item in reversed(api.scope.stack): if isinstance(item, MypyFile): @@ -215,21 +448,6 @@ def get_current_module(api: TypeChecker) -> MypyFile: return current_module -def make_oneoff_named_tuple(api: TypeChecker, name: str, fields: 'OrderedDict[str, MypyType]') -> TupleType: - current_module = get_current_module(api) - namedtuple_info = add_new_class_for_module(current_module, name, - bases=[api.named_generic_type('typing.NamedTuple', [])], - fields=fields) - return TupleType(list(fields.values()), fallback=Instance(namedtuple_info, [])) - - -def make_tuple(api: 'TypeChecker', fields: List[MypyType]) -> TupleType: - # fallback for tuples is any builtins.tuple instance - fallback = api.named_generic_type('builtins.tuple', - [AnyType(TypeOfAny.special_form)]) - return TupleType(fields, fallback=fallback) - - def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType: if isinstance(typ, UnionType): converted_items = [] @@ -252,13 +470,6 @@ def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType: return typ -def make_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, MypyType]', - required_keys: Set[str]) -> TypedDictType: - object_type = api.named_generic_type('mypy_extensions._TypedDict', []) - typed_dict_type = TypedDictType(fields, required_keys=required_keys, fallback=object_type) - return typed_dict_type - - def resolve_string_attribute_value(attr_expr: Expression, django_context: 'DjangoContext') -> Optional[str]: if isinstance(attr_expr, StrExpr): return attr_expr.value @@ -272,113 +483,36 @@ def resolve_string_attribute_value(attr_expr: Expression, django_context: 'Djang return None -def get_semanal_api(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> SemanticAnalyzer: - if not isinstance(ctx.api, SemanticAnalyzer): - raise ValueError('Not a SemanticAnalyzer') - return ctx.api - - -def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionContext]) -> TypeChecker: - if not isinstance(ctx.api, TypeChecker): - raise ValueError('Not a TypeChecker') - return ctx.api - - -def is_model_subclass_info(info: TypeInfo, django_context: 'DjangoContext') -> bool: - return (info.fullname in django_context.all_registered_model_class_fullnames - or info.has_base(fullnames.MODEL_CLASS_FULLNAME)) +def new_typeinfo(name: str, + *, + bases: List[Instance], + module_name: str) -> TypeInfo: + """ + Construct new TypeInfo instance. Cannot be used for nested classes. + """ + class_def = ClassDef(name, Block([])) + class_def.fullname = module_name + '.' + name + info = TypeInfo(SymbolTable(), class_def, module_name) + info.bases = bases + calculate_mro(info) + info.metaclass_type = info.calculate_metaclass_type() -def check_types_compatible(ctx: Union[FunctionContext, MethodContext], - *, expected_type: MypyType, actual_type: MypyType, error_message: str) -> None: - api = get_typechecker_api(ctx) - api.check_subtype(actual_type, expected_type, - ctx.context, error_message, - 'got', 'expected') + class_def.info = info + return info -def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType) -> None: - # type=: type of the variable itself - var = Var(name=name, type=sym_type) - # var.info: type of the object variable is bound to - var.info = info - var._fullname = info.fullname + '.' + name - var.is_initialized_in_class = True - var.is_inferred = True - info.names[name] = SymbolTableNode(MDEF, var, - plugin_generated=True) +def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]: + metaclass_sym = info.names.get('Meta') + if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo): + return metaclass_sym.node + return None def build_unannotated_method_args(method_node: FuncDef) -> Tuple[List[Argument], MypyType]: prepared_arguments = [] - try: - arguments = method_node.arguments[1:] - except AttributeError: - arguments = [] - for argument in arguments: + for argument in method_node.arguments[1:]: argument.type_annotation = AnyType(TypeOfAny.unannotated) prepared_arguments.append(argument) return_type = AnyType(TypeOfAny.unannotated) return prepared_arguments, return_type - - -def copy_method_to_another_class(ctx: ClassDefContext, self_type: Instance, - new_method_name: str, method_node: FuncDef) -> None: - semanal_api = get_semanal_api(ctx) - if method_node.type is None: - if not semanal_api.final_iteration: - semanal_api.defer() - return - - arguments, return_type = build_unannotated_method_args(method_node) - add_method(ctx, - new_method_name, - args=arguments, - return_type=return_type, - self_type=self_type) - return - - method_type = method_node.type - if not isinstance(method_type, CallableType): - if not semanal_api.final_iteration: - semanal_api.defer() - return - - arguments = [] - bound_return_type = semanal_api.anal_type(method_type.ret_type, - allow_placeholder=True) - - assert bound_return_type is not None - - if isinstance(bound_return_type, PlaceholderNode): - return - - for arg_name, arg_type, original_argument in zip(method_type.arg_names[1:], - method_type.arg_types[1:], - method_node.arguments[1:]): - bound_arg_type = semanal_api.anal_type(arg_type, allow_placeholder=True) - if bound_arg_type is None and not semanal_api.final_iteration: - semanal_api.defer() - return - - assert bound_arg_type is not None - - if isinstance(bound_arg_type, PlaceholderNode): - return - - var = Var(name=original_argument.variable.name, - type=arg_type) - var.line = original_argument.variable.line - var.column = original_argument.variable.column - argument = Argument(variable=var, - type_annotation=bound_arg_type, - initializer=original_argument.initializer, - kind=original_argument.kind) - argument.set_line(original_argument) - arguments.append(argument) - - add_method(ctx, - new_method_name, - args=arguments, - return_type=bound_return_type, - self_type=self_type) diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 3f64ba417..d478ce59a 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -1,5 +1,4 @@ import configparser -from functools import partial from typing import Callable, Dict, List, NoReturn, Optional, Tuple, cast from django.db.models.fields.related import RelatedField @@ -10,44 +9,33 @@ ) from mypy.types import Type as MypyType -import mypy_django_plugin.transformers.orm_lookups from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.lib import fullnames, helpers -from mypy_django_plugin.transformers import ( - fields, forms, init_create, meta, querysets, request, settings, +from mypy_django_plugin.transformers.fields import FieldContructorCallback +from mypy_django_plugin.transformers.forms import ( + FormCallback, GetFormCallback, GetFormClassCallback, +) +from mypy_django_plugin.transformers.init_create import ( + ModelCreateCallback, ModelInitCallback, ) from mypy_django_plugin.transformers.managers import ( - create_new_manager_class_from_from_queryset_method, + ManagerFromQuerySetCallback, +) +from mypy_django_plugin.transformers.meta import MetaGetFieldCallback +from mypy_django_plugin.transformers.models import ModelCallback +from mypy_django_plugin.transformers.orm_lookups import ( + QuerySetFilterTypecheckCallback, +) +from mypy_django_plugin.transformers.querysets import ( + QuerySetValuesCallback, QuerySetValuesListCallback, +) +from mypy_django_plugin.transformers.related_managers import ( + GetRelatedManagerCallback, +) +from mypy_django_plugin.transformers.request import RequestUserModelCallback +from mypy_django_plugin.transformers.settings import ( + GetTypeOfSettingsAttributeCallback, GetUserModelCallback, ) -from mypy_django_plugin.transformers.models import process_model_class - - -def transform_model_class(ctx: ClassDefContext, - django_context: DjangoContext) -> None: - sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MODEL_CLASS_FULLNAME) - - if sym is not None and isinstance(sym.node, TypeInfo): - helpers.get_django_metadata(sym.node)['model_bases'][ctx.cls.fullname] = 1 - else: - if not ctx.api.final_iteration: - ctx.api.defer() - return - - process_model_class(ctx, django_context) - - -def transform_form_class(ctx: ClassDefContext) -> None: - sym = ctx.api.lookup_fully_qualified_or_none(fullnames.BASEFORM_CLASS_FULLNAME) - if sym is not None and isinstance(sym.node, TypeInfo): - helpers.get_django_metadata(sym.node)['baseform_bases'][ctx.cls.fullname] = 1 - - forms.make_meta_nested_class_inherit_from_any(ctx) - - -def add_new_manager_base(ctx: ClassDefContext) -> None: - sym = ctx.api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME) - if sym is not None and isinstance(sym.node, TypeInfo): - helpers.get_django_metadata(sym.node)['manager_bases'][ctx.cls.fullname] = 1 def extract_django_settings_module(config_file_path: Optional[str]) -> str: @@ -136,7 +124,21 @@ def _get_typeinfo_or_none(self, class_name: str) -> Optional[TypeInfo]: def _new_dependency(self, module: str) -> Tuple[int, str, int]: return 10, module, -1 + def _add_new_manager_base(self, fullname: str) -> None: + sym = self.lookup_fully_qualified(fullnames.MANAGER_CLASS_FULLNAME) + if sym is not None and isinstance(sym.node, TypeInfo): + helpers.get_django_metadata(sym.node)['manager_bases'][fullname] = 1 + + def _add_new_form_base(self, fullname: str) -> None: + sym = self.lookup_fully_qualified(fullnames.BASEFORM_CLASS_FULLNAME) + if sym is not None and isinstance(sym.node, TypeInfo): + helpers.get_django_metadata(sym.node)['baseform_bases'][fullname] = 1 + def get_additional_deps(self, file: MypyFile) -> List[Tuple[int, str, int]]: + # load QuerySet and Manager together (for as_manager) + if file.fullname == 'django.db.models.query': + return [self._new_dependency('django.db.models.manager')] + # for settings if file.fullname == 'django.conf' and self.django_context.django_settings_module: return [self._new_dependency(self.django_context.django_settings_module)] @@ -183,19 +185,15 @@ def get_additional_deps(self, file: MypyFile) -> List[Tuple[int, str, int]]: def get_function_hook(self, fullname: str ) -> Optional[Callable[[FunctionContext], MypyType]]: if fullname == 'django.contrib.auth.get_user_model': - return partial(settings.get_user_model_hook, django_context=self.django_context) - - manager_bases = self._get_current_manager_bases() - if fullname in manager_bases: - return querysets.determine_proper_manager_type + return GetUserModelCallback(self) info = self._get_typeinfo_or_none(fullname) if info: if info.has_base(fullnames.FIELD_FULLNAME): - return partial(fields.transform_into_proper_return_type, django_context=self.django_context) + return FieldContructorCallback(self) - if helpers.is_model_subclass_info(info, self.django_context): - return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context) + if self.django_context.is_model_subclass(info): + return ModelInitCallback(self) return None def get_method_hook(self, fullname: str @@ -204,59 +202,68 @@ def get_method_hook(self, fullname: str if method_name == 'get_form_class': info = self._get_typeinfo_or_none(class_fullname) if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME): - return forms.extract_proper_type_for_get_form_class + return GetFormClassCallback(self) if method_name == 'get_form': info = self._get_typeinfo_or_none(class_fullname) if info and info.has_base(fullnames.FORM_MIXIN_CLASS_FULLNAME): - return forms.extract_proper_type_for_get_form + return GetFormCallback(self) if method_name == 'values': info = self._get_typeinfo_or_none(class_fullname) if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME): - return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context) + return QuerySetValuesCallback(self) + # return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context) if method_name == 'values_list': info = self._get_typeinfo_or_none(class_fullname) if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME): - return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context) + return QuerySetValuesListCallback(self) + # return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context) if method_name == 'get_field': info = self._get_typeinfo_or_none(class_fullname) if info and info.has_base(fullnames.OPTIONS_CLASS_FULLNAME): - return partial(meta.return_proper_field_type_from_get_field, django_context=self.django_context) + return MetaGetFieldCallback(self) manager_classes = self._get_current_manager_bases() if class_fullname in manager_classes and method_name == 'create': - return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context) + return ModelCreateCallback(self) + if class_fullname in manager_classes and method_name in {'filter', 'get', 'exclude'}: - return partial(mypy_django_plugin.transformers.orm_lookups.typecheck_queryset_filter, - django_context=self.django_context) + return QuerySetFilterTypecheckCallback(self) + return None def get_base_class_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: if (fullname in self.django_context.all_registered_model_class_fullnames or fullname in self._get_current_model_bases()): - return partial(transform_model_class, django_context=self.django_context) + return ModelCallback(self) if fullname in self._get_current_manager_bases(): - return add_new_manager_base + self._add_new_manager_base(fullname) + return None if fullname in self._get_current_form_bases(): - return transform_form_class + self._add_new_form_base(fullname) + return FormCallback(self) + return None def get_attribute_hook(self, fullname: str ) -> Optional[Callable[[AttributeContext], MypyType]]: class_name, _, attr_name = fullname.rpartition('.') if class_name == fullnames.DUMMY_SETTINGS_BASE_CLASS: - return partial(settings.get_type_of_settings_attribute, - django_context=self.django_context) + return GetTypeOfSettingsAttributeCallback(self) info = self._get_typeinfo_or_none(class_name) if info and info.has_base(fullnames.HTTPREQUEST_CLASS_FULLNAME) and attr_name == 'user': - return partial(request.set_auth_user_model_as_type_for_request_user, django_context=self.django_context) + return RequestUserModelCallback(self) + + if info and info.has_base(fullnames.MODEL_CLASS_FULLNAME): + return GetRelatedManagerCallback(self) + return None def get_dynamic_class_hook(self, fullname: str @@ -265,7 +272,7 @@ def get_dynamic_class_hook(self, fullname: str class_name, _, _ = fullname.rpartition('.') info = self._get_typeinfo_or_none(class_name) if info and info.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME): - return create_new_manager_class_from_from_queryset_method + return ManagerFromQuerySetCallback(self) return None diff --git a/mypy_django_plugin/py.typed b/mypy_django_plugin/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index b88fdbf59..7a1bb2171 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -2,42 +2,19 @@ from django.db.models.fields import Field from django.db.models.fields.related import RelatedField +from mypy.checkexpr import FunctionContext from mypy.nodes import AssignmentStmt, NameExpr, TypeInfo -from mypy.plugin import FunctionContext from mypy.types import AnyType, Instance from mypy.types import Type as MypyType from mypy.types import TypeOfAny -from mypy_django_plugin.django.context import DjangoContext -from mypy_django_plugin.lib import fullnames, helpers +from mypy_django_plugin.lib import chk_helpers, fullnames, helpers -def _get_current_field_from_assignment(ctx: FunctionContext, django_context: DjangoContext) -> Optional[Field]: - outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class() - if (outer_model_info is None - or not helpers.is_model_subclass_info(outer_model_info, django_context)): - return None - - field_name = None - for stmt in outer_model_info.defn.defs.body: - if isinstance(stmt, AssignmentStmt): - if stmt.rvalue == ctx.context: - if not isinstance(stmt.lvalues[0], NameExpr): - return None - field_name = stmt.lvalues[0].name - break - if field_name is None: - return None - - model_cls = django_context.get_model_class_by_fullname(outer_model_info.fullname) - if model_cls is None: - return None - - current_field = model_cls._meta.get_field(field_name) - return current_field - - -def reparametrize_related_field_type(related_field_type: Instance, set_type, get_type) -> Instance: +def reparametrize_related_field_type(related_field_type: Instance, + set_type: MypyType, + get_type: MypyType + ) -> Instance: args = [ helpers.convert_any_to_type(related_field_type.args[0], set_type), helpers.convert_any_to_type(related_field_type.args[1], get_type), @@ -45,63 +22,6 @@ def reparametrize_related_field_type(related_field_type: Instance, set_type, get return helpers.reparametrize_instance(related_field_type, new_args=args) -def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: - current_field = _get_current_field_from_assignment(ctx, django_context) - if current_field is None: - return AnyType(TypeOfAny.from_error) - - assert isinstance(current_field, RelatedField) - - related_model_cls = django_context.get_field_related_model_cls(current_field) - if related_model_cls is None: - return AnyType(TypeOfAny.from_error) - - default_related_field_type = set_descriptor_types_for_field(ctx) - - # self reference with abstract=True on the model where ForeignKey is defined - current_model_cls = current_field.model - if (current_model_cls._meta.abstract - and current_model_cls == related_model_cls): - # for all derived non-abstract classes, set variable with this name to - # __get__/__set__ of ForeignKey of derived model - for model_cls in django_context.all_registered_model_classes: - if issubclass(model_cls, current_model_cls) and not model_cls._meta.abstract: - derived_model_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), model_cls) - if derived_model_info is not None: - fk_ref_type = Instance(derived_model_info, []) - derived_fk_type = reparametrize_related_field_type(default_related_field_type, - set_type=fk_ref_type, get_type=fk_ref_type) - helpers.add_new_sym_for_info(derived_model_info, - name=current_field.name, - sym_type=derived_fk_type) - - related_model = related_model_cls - related_model_to_set = related_model_cls - if related_model_to_set._meta.proxy_for_model is not None: - related_model_to_set = related_model_to_set._meta.proxy_for_model - - typechecker_api = helpers.get_typechecker_api(ctx) - - related_model_info = helpers.lookup_class_typeinfo(typechecker_api, related_model) - if related_model_info is None: - # maybe no type stub - related_model_type = AnyType(TypeOfAny.unannotated) - else: - related_model_type = Instance(related_model_info, []) # type: ignore - - related_model_to_set_info = helpers.lookup_class_typeinfo(typechecker_api, related_model_to_set) - if related_model_to_set_info is None: - # maybe no type stub - related_model_to_set_type = AnyType(TypeOfAny.unannotated) - else: - related_model_to_set_type = Instance(related_model_to_set_info, []) # type: ignore - - # replace Any with referred_to_type - return reparametrize_related_field_type(default_related_field_type, - set_type=related_model_to_set_type, - get_type=related_model_type) - - def get_field_descriptor_types(field_info: TypeInfo, is_nullable: bool) -> Tuple[MypyType, MypyType]: set_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_set_type', is_nullable=is_nullable) @@ -110,11 +30,16 @@ def get_field_descriptor_types(field_info: TypeInfo, is_nullable: bool) -> Tuple return set_type, get_type +def get_field_type(field_info: TypeInfo, is_nullable: bool) -> Instance: + set_type, get_type = get_field_descriptor_types(field_info, is_nullable) + return Instance(field_info, [set_type, get_type]) + + def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance: default_return_type = cast(Instance, ctx.default_return_type) is_nullable = False - null_expr = helpers.get_call_argument_by_name(ctx, 'null') + null_expr = chk_helpers.get_call_argument_by_name(ctx, 'null') if null_expr is not None: is_nullable = helpers.parse_bool(null_expr) or False @@ -122,36 +47,123 @@ def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance: return helpers.reparametrize_instance(default_return_type, [set_type, get_type]) -def determine_type_of_array_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: - default_return_type = set_descriptor_types_for_field(ctx) - - base_field_arg_type = helpers.get_call_argument_type_by_name(ctx, 'base_field') - if not base_field_arg_type or not isinstance(base_field_arg_type, Instance): - return default_return_type - - base_type = base_field_arg_type.args[1] # extract __get__ type - args = [] - for default_arg in default_return_type.args: - args.append(helpers.convert_any_to_type(default_arg, base_type)) - - return helpers.reparametrize_instance(default_return_type, args) - - -def transform_into_proper_return_type(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: - default_return_type = ctx.default_return_type - assert isinstance(default_return_type, Instance) - - outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class() - if (outer_model_info is None - or not helpers.is_model_subclass_info(outer_model_info, django_context)): - return ctx.default_return_type - - assert isinstance(outer_model_info, TypeInfo) - - if helpers.has_any_of_bases(default_return_type.type, fullnames.RELATED_FIELDS_CLASSES): - return fill_descriptor_types_for_related_field(ctx, django_context) - - if default_return_type.type.has_base(fullnames.ARRAY_FIELD_FULLNAME): - return determine_type_of_array_field(ctx, django_context) - - return set_descriptor_types_for_field(ctx) +class FieldContructorCallback(helpers.GetFunctionCallback): + default_return_type: Instance + + def _get_field_from_model_cls_assignment(self) -> Optional[Field]: + """ Use AssignmentStmt inside model class declaration to find instance of Field (from DjangoContext)""" + outer_model_info = self.type_checker.scope.active_class() + if (outer_model_info is None + or not self.django_context.is_model_subclass(outer_model_info)): + return None + + field_name = None + for stmt in outer_model_info.defn.defs.body: + if isinstance(stmt, AssignmentStmt): + if stmt.rvalue == self.ctx.context: + if not isinstance(stmt.lvalues[0], NameExpr): + return None + field_name = stmt.lvalues[0].name + break + if field_name is None: + return None + + model_cls = self.django_context.get_model_class_by_fullname(outer_model_info.fullname) + if model_cls is None: + return None + + current_field = model_cls._meta.get_field(field_name) + return current_field + + def current_field_type(self) -> Instance: + is_nullable = False + null_expr = chk_helpers.get_call_argument_by_name(self.ctx, 'null') + if null_expr is not None: + is_nullable = helpers.parse_bool(null_expr) or False + + set_type, get_type = get_field_descriptor_types(self.default_return_type.type, is_nullable) + return helpers.reparametrize_instance(self.default_return_type, [set_type, get_type]) + + def array_field_type(self) -> MypyType: + default_array_field_type = self.current_field_type() + + base_field_arg_type = chk_helpers.get_call_argument_type_by_name(self.ctx, 'base_field') + if not base_field_arg_type or not isinstance(base_field_arg_type, Instance): + return default_array_field_type + + base_type = base_field_arg_type.args[1] # extract __get__ type + args = [] + for default_arg in default_array_field_type.args: + args.append(helpers.convert_any_to_type(default_arg, base_type)) + + return helpers.reparametrize_instance(default_array_field_type, args) + + def related_field_type(self) -> MypyType: + current_field = self._get_field_from_model_cls_assignment() + if current_field is None: + return AnyType(TypeOfAny.from_error) + + assert isinstance(current_field, RelatedField) + + related_model_cls = self.django_context.get_field_related_model_cls(current_field) + if related_model_cls is None: + return AnyType(TypeOfAny.from_error) + + default_related_field_type = set_descriptor_types_for_field(self.ctx) + + # self reference with abstract=True on the model where ForeignKey is defined + current_model_cls = current_field.model + if (current_model_cls._meta.abstract + and current_model_cls == related_model_cls): + # for all derived non-abstract classes, set variable with this name to + # __get__/__set__ of ForeignKey of derived model + for model_cls in self.django_context.all_registered_model_classes: + if issubclass(model_cls, current_model_cls) and not model_cls._meta.abstract: + derived_model_info = helpers.lookup_class_typeinfo(self.type_checker, model_cls) + if derived_model_info is not None: + fk_ref_type = Instance(derived_model_info, []) + derived_fk_type = reparametrize_related_field_type(default_related_field_type, + set_type=fk_ref_type, get_type=fk_ref_type) + chk_helpers.add_new_sym_for_info(derived_model_info, + name=current_field.name, + sym_type=derived_fk_type) + + related_model = related_model_cls + related_model_to_set = related_model_cls + if related_model_to_set._meta.proxy_for_model is not None: + related_model_to_set = related_model_to_set._meta.proxy_for_model + + related_model_info = helpers.lookup_class_typeinfo(self.type_checker, related_model) + if related_model_info is None: + # maybe no type stub + related_model_type = AnyType(TypeOfAny.unannotated) + else: + related_model_type = Instance(related_model_info, []) # type: ignore + + related_model_to_set_info = helpers.lookup_class_typeinfo(self.type_checker, related_model_to_set) + if related_model_to_set_info is None: + # maybe no type stub + related_model_to_set_type = AnyType(TypeOfAny.unannotated) + else: + related_model_to_set_type = Instance(related_model_to_set_info, []) # type: ignore + + # replace Any with referred_to_type + return reparametrize_related_field_type(default_related_field_type, + set_type=related_model_to_set_type, + get_type=related_model_type) + + def get_function_return_type(self) -> MypyType: + outer_model_info = self.type_checker.scope.active_class() + if (outer_model_info is None + or not self.django_context.is_model_subclass(outer_model_info)): + return self.default_return_type + + assert isinstance(outer_model_info, TypeInfo) + + if helpers.has_any_of_bases(self.default_return_type.type, fullnames.RELATED_FIELDS_CLASSES): + return self.related_field_type() + + if self.default_return_type.type.has_base(fullnames.ARRAY_FIELD_FULLNAME): + return self.array_field_type() + + return self.current_field_type() diff --git a/mypy_django_plugin/transformers/forms.py b/mypy_django_plugin/transformers/forms.py index 7bd0e1116..4085b72ef 100644 --- a/mypy_django_plugin/transformers/forms.py +++ b/mypy_django_plugin/transformers/forms.py @@ -1,52 +1,47 @@ from typing import Optional -from mypy.plugin import ClassDefContext, MethodContext from mypy.types import CallableType, Instance, NoneTyp from mypy.types import Type as MypyType from mypy.types import TypeType -from mypy_django_plugin.lib import helpers +from mypy_django_plugin.lib import chk_helpers, helpers -def make_meta_nested_class_inherit_from_any(ctx: ClassDefContext) -> None: - meta_node = helpers.get_nested_meta_node_for_current_class(ctx.cls.info) - if meta_node is None: - if not ctx.api.final_iteration: - ctx.api.defer() - else: +class FormCallback(helpers.ClassDefPluginCallback): + def modify_class_defn(self) -> None: + meta_node = helpers.get_nested_meta_node_for_current_class(self.class_defn.info) + if meta_node is None: + return None meta_node.fallback_to_any = True -def get_specified_form_class(object_type: Instance) -> Optional[TypeType]: - form_class_sym = object_type.type.get('form_class') - if form_class_sym and isinstance(form_class_sym.type, CallableType): - return TypeType(form_class_sym.type.ret_type) - return None +class FormMethodCallback(helpers.GetMethodCallback): + def get_specified_form_class(self) -> Optional[TypeType]: + form_class_sym = self.callee_type.type.get('form_class') + if form_class_sym and isinstance(form_class_sym.type, CallableType): + return TypeType(form_class_sym.type.ret_type) + return None -def extract_proper_type_for_get_form(ctx: MethodContext) -> MypyType: - object_type = ctx.type - assert isinstance(object_type, Instance) +class GetFormCallback(FormMethodCallback): + def get_method_return_type(self) -> MypyType: + form_class_type = chk_helpers.get_call_argument_type_by_name(self.ctx, 'form_class') + if form_class_type is None or isinstance(form_class_type, NoneTyp): + form_class_type = self.get_specified_form_class() - form_class_type = helpers.get_call_argument_type_by_name(ctx, 'form_class') - if form_class_type is None or isinstance(form_class_type, NoneTyp): - form_class_type = get_specified_form_class(object_type) + if isinstance(form_class_type, TypeType) and isinstance(form_class_type.item, Instance): + return form_class_type.item - if isinstance(form_class_type, TypeType) and isinstance(form_class_type.item, Instance): - return form_class_type.item + if isinstance(form_class_type, CallableType) and isinstance(form_class_type.ret_type, Instance): + return form_class_type.ret_type - if isinstance(form_class_type, CallableType) and isinstance(form_class_type.ret_type, Instance): - return form_class_type.ret_type + return self.default_return_type - return ctx.default_return_type +class GetFormClassCallback(FormMethodCallback): + def get_method_return_type(self) -> MypyType: + form_class_type = self.get_specified_form_class() + if form_class_type is None: + return self.default_return_type -def extract_proper_type_for_get_form_class(ctx: MethodContext) -> MypyType: - object_type = ctx.type - assert isinstance(object_type, Instance) - - form_class_type = get_specified_form_class(object_type) - if form_class_type is None: - return ctx.default_return_type - - return form_class_type + return form_class_type diff --git a/mypy_django_plugin/transformers/init_create.py b/mypy_django_plugin/transformers/init_create.py index fe0b19ee2..e36f955d4 100644 --- a/mypy_django_plugin/transformers/init_create.py +++ b/mypy_django_plugin/transformers/init_create.py @@ -6,7 +6,7 @@ from mypy.types import Type as MypyType from mypy_django_plugin.django.context import DjangoContext -from mypy_django_plugin.lib import helpers +from mypy_django_plugin.lib import chk_helpers, helpers def get_actual_types(ctx: Union[MethodContext, FunctionContext], @@ -32,7 +32,7 @@ def get_actual_types(ctx: Union[MethodContext, FunctionContext], def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext, model_cls: Type[Model], method: str) -> MypyType: - typechecker_api = helpers.get_typechecker_api(ctx) + typechecker_api = chk_helpers.get_typechecker_api(ctx) expected_types = django_context.get_expected_types(typechecker_api, model_cls, method=method) expected_keys = [key for key in expected_types.keys() if key != 'pk'] @@ -42,34 +42,36 @@ def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_co model_cls.__name__), ctx.context) continue - helpers.check_types_compatible(ctx, - expected_type=expected_types[actual_name], - actual_type=actual_type, - error_message='Incompatible type for "{}" of "{}"'.format(actual_name, - model_cls.__name__)) + error_message = 'Incompatible type for "{}" of "{}"'.format(actual_name, model_cls.__name__) + chk_helpers.check_types_compatible(ctx, + expected_type=expected_types[actual_name], + actual_type=actual_type, + error_message=error_message) return ctx.default_return_type -def redefine_and_typecheck_model_init(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: - assert isinstance(ctx.default_return_type, Instance) +class ModelInitCallback(helpers.GetFunctionCallback): + def get_function_return_type(self) -> MypyType: + assert isinstance(self.default_return_type, Instance) - model_fullname = ctx.default_return_type.type.fullname - model_cls = django_context.get_model_class_by_fullname(model_fullname) - if model_cls is None: - return ctx.default_return_type + model_fullname = self.default_return_type.type.fullname + model_cls = self.django_context.get_model_class_by_fullname(model_fullname) + if model_cls is None: + return self.default_return_type - return typecheck_model_method(ctx, django_context, model_cls, '__init__') + return typecheck_model_method(self.ctx, self.django_context, model_cls, '__init__') -def redefine_and_typecheck_model_create(ctx: MethodContext, django_context: DjangoContext) -> MypyType: - if not isinstance(ctx.default_return_type, Instance): - # only work with ctx.default_return_type = model Instance - return ctx.default_return_type +class ModelCreateCallback(helpers.GetMethodCallback): + def get_method_return_type(self) -> MypyType: + if not isinstance(self.default_return_type, Instance): + # only work with ctx.default_return_type = model Instance + return self.default_return_type - model_fullname = ctx.default_return_type.type.fullname - model_cls = django_context.get_model_class_by_fullname(model_fullname) - if model_cls is None: - return ctx.default_return_type + model_fullname = self.default_return_type.type.fullname + model_cls = self.django_context.get_model_class_by_fullname(model_fullname) + if model_cls is None: + return self.default_return_type - return typecheck_model_method(ctx, django_context, model_cls, 'create') + return typecheck_model_method(self.ctx, self.django_context, model_cls, 'create') diff --git a/mypy_django_plugin/transformers/managers.py b/mypy_django_plugin/transformers/managers.py index 88201b439..edac1cce6 100644 --- a/mypy_django_plugin/transformers/managers.py +++ b/mypy_django_plugin/transformers/managers.py @@ -1,77 +1,76 @@ -from mypy.nodes import ( - GDEF, FuncDef, MemberExpr, NameExpr, RefExpr, StrExpr, SymbolTableNode, TypeInfo, -) -from mypy.plugin import ClassDefContext, DynamicClassDefContext -from mypy.types import AnyType, Instance, TypeOfAny +from typing import List, Tuple + +from mypy.nodes import Argument, FuncDef, NameExpr, StrExpr, TypeInfo +from mypy.plugin import ClassDefContext +from mypy.types import AnyType, Instance +from mypy.types import Type as MypyType +from mypy.types import TypeOfAny from mypy_django_plugin.lib import fullnames, helpers -def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefContext) -> None: - semanal_api = helpers.get_semanal_api(ctx) - - callee = ctx.call.callee - assert isinstance(callee, MemberExpr) - assert isinstance(callee.expr, RefExpr) - - base_manager_info = callee.expr.node - if base_manager_info is None: - if not semanal_api.final_iteration: - semanal_api.defer() - return - - assert isinstance(base_manager_info, TypeInfo) - new_manager_info = semanal_api.basic_new_typeinfo(ctx.name, - basetype_or_fallback=Instance(base_manager_info, - [AnyType(TypeOfAny.unannotated)])) - new_manager_info.line = ctx.call.line - new_manager_info.defn.line = ctx.call.line - new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type() - - current_module = semanal_api.cur_mod_node - current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info, - plugin_generated=True) - passed_queryset = ctx.call.args[0] - assert isinstance(passed_queryset, NameExpr) - - derived_queryset_fullname = passed_queryset.fullname - assert derived_queryset_fullname is not None - - sym = semanal_api.lookup_fully_qualified_or_none(derived_queryset_fullname) - assert sym is not None - if sym.node is None: - if not semanal_api.final_iteration: - semanal_api.defer() - else: +def build_unannotated_method_args(method_node: FuncDef) -> Tuple[List[Argument], MypyType]: + prepared_arguments = [] + for argument in method_node.arguments[1:]: + argument.type_annotation = AnyType(TypeOfAny.unannotated) + prepared_arguments.append(argument) + return_type = AnyType(TypeOfAny.unannotated) + return prepared_arguments, return_type + + +class ManagerFromQuerySetCallback(helpers.DynamicClassFromMethodCallback): + def create_new_dynamic_class(self) -> None: + + base_manager_info = self.callee.expr.node # type: ignore + + if base_manager_info is None and not self.defer_till_next_iteration(reason='base_manager_info is None'): + # what exception should be thrown here? + return + + assert isinstance(base_manager_info, TypeInfo) + + new_manager_info, current_module = self.generate_manager_info_and_module(base_manager_info) + + passed_queryset = self.call_expr.args[0] + assert isinstance(passed_queryset, NameExpr) + + derived_queryset_fullname = passed_queryset.fullname + assert derived_queryset_fullname is not None + + sym = self.semanal_api.lookup_fully_qualified_or_none(derived_queryset_fullname) + assert sym is not None + if sym.node is None and not self.defer_till_next_iteration(reason='sym.node is None'): # inherit from Any to prevent false-positives, if queryset class cannot be resolved new_manager_info.fallback_to_any = True - return - - derived_queryset_info = sym.node - assert isinstance(derived_queryset_info, TypeInfo) - - if len(ctx.call.args) > 1: - expr = ctx.call.args[1] - assert isinstance(expr, StrExpr) - custom_manager_generated_name = expr.value - else: - custom_manager_generated_name = base_manager_info.name + 'From' + derived_queryset_info.name - - custom_manager_generated_fullname = '.'.join(['django.db.models.manager', custom_manager_generated_name]) - if 'from_queryset_managers' not in base_manager_info.metadata: - base_manager_info.metadata['from_queryset_managers'] = {} - base_manager_info.metadata['from_queryset_managers'][custom_manager_generated_fullname] = new_manager_info.fullname - - class_def_context = ClassDefContext(cls=new_manager_info.defn, - reason=ctx.call, api=semanal_api) - self_type = Instance(new_manager_info, []) - # we need to copy all methods in MRO before django.db.models.query.QuerySet - for class_mro_info in derived_queryset_info.mro: - if class_mro_info.fullname == fullnames.QUERYSET_CLASS_FULLNAME: - break - for name, sym in class_mro_info.names.items(): - if isinstance(sym.node, FuncDef): - helpers.copy_method_to_another_class(class_def_context, - self_type, - new_method_name=name, - method_node=sym.node) + return + + derived_queryset_info = sym.node + assert isinstance(derived_queryset_info, TypeInfo) + + if len(self.call_expr.args) > 1: + expr = self.call_expr.args[1] + assert isinstance(expr, StrExpr) + custom_manager_generated_name = expr.value + else: + custom_manager_generated_name = base_manager_info.name + 'From' + derived_queryset_info.name + + custom_manager_generated_fullname = '.'.join(['django.db.models.manager', custom_manager_generated_name]) + if 'from_queryset_managers' not in base_manager_info.metadata: + base_manager_info.metadata['from_queryset_managers'] = {} + base_manager_info.metadata['from_queryset_managers'][ + custom_manager_generated_fullname] = new_manager_info.fullname + class_def_context = ClassDefContext( + cls=new_manager_info.defn, + reason=self.call_expr, api=self.semanal_api) + self_type = Instance(new_manager_info, []) + # we need to copy all methods in MRO before django.db.models.query.QuerySet + for class_mro_info in derived_queryset_info.mro: + if class_mro_info.fullname == fullnames.QUERYSET_CLASS_FULLNAME: + break + for name, sym in class_mro_info.names.items(): + if isinstance(sym.node, FuncDef): + self.copy_method_to_another_class( + class_def_context, + self_type, + new_method_name=name, + method_node=sym.node) diff --git a/mypy_django_plugin/transformers/meta.py b/mypy_django_plugin/transformers/meta.py index 64e6e12fe..03f6e26e4 100644 --- a/mypy_django_plugin/transformers/meta.py +++ b/mypy_django_plugin/transformers/meta.py @@ -1,52 +1,46 @@ from django.core.exceptions import FieldDoesNotExist -from mypy.plugin import MethodContext from mypy.types import AnyType, Instance from mypy.types import Type as MypyType from mypy.types import TypeOfAny -from mypy_django_plugin.django.context import DjangoContext -from mypy_django_plugin.lib import helpers +from mypy_django_plugin.lib import chk_helpers, helpers -def _get_field_instance(ctx: MethodContext, field_fullname: str) -> MypyType: - field_info = helpers.lookup_fully_qualified_typeinfo(helpers.get_typechecker_api(ctx), - field_fullname) - if field_info is None: - return AnyType(TypeOfAny.unannotated) - return Instance(field_info, [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)]) +class MetaGetFieldCallback(helpers.GetMethodCallback): + def _get_field_instance(self, field_fullname: str) -> MypyType: + field_info = helpers.lookup_fully_qualified_typeinfo(self.type_checker, field_fullname) + if field_info is None: + return AnyType(TypeOfAny.unannotated) + return Instance(field_info, [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)]) + def get_method_return_type(self) -> MypyType: + # bail if list of generic params is empty + if len(self.callee_type.args) == 0: + return self.default_return_type -def return_proper_field_type_from_get_field(ctx: MethodContext, django_context: DjangoContext) -> MypyType: - # Options instance - assert isinstance(ctx.type, Instance) + model_type = self.callee_type.args[0] + if not isinstance(model_type, Instance): + return self.default_return_type - # bail if list of generic params is empty - if len(ctx.type.args) == 0: - return ctx.default_return_type + model_cls = self.django_context.get_model_class_by_fullname(model_type.type.fullname) + if model_cls is None: + return self.default_return_type - model_type = ctx.type.args[0] - if not isinstance(model_type, Instance): - return ctx.default_return_type + field_name_expr = chk_helpers.get_call_argument_by_name(self.ctx, 'field_name') + if field_name_expr is None: + return self.default_return_type - model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname) - if model_cls is None: - return ctx.default_return_type + field_name = helpers.resolve_string_attribute_value(field_name_expr, self.django_context) + if field_name is None: + return self.default_return_type - field_name_expr = helpers.get_call_argument_by_name(ctx, 'field_name') - if field_name_expr is None: - return ctx.default_return_type + try: + field = model_cls._meta.get_field(field_name) + except FieldDoesNotExist as exc: + # if model is abstract, do not raise exception, skip false positives + if not model_cls._meta.abstract: + self.ctx.api.fail(exc.args[0], self.ctx.context) + return AnyType(TypeOfAny.from_error) - field_name = helpers.resolve_string_attribute_value(field_name_expr, django_context) - if field_name is None: - return ctx.default_return_type - - try: - field = model_cls._meta.get_field(field_name) - except FieldDoesNotExist as exc: - # if model is abstract, do not raise exception, skip false positives - if not model_cls._meta.abstract: - ctx.api.fail(exc.args[0], ctx.context) - return AnyType(TypeOfAny.from_error) - - field_fullname = helpers.get_class_fullname(field.__class__) - return _get_field_instance(ctx, field_fullname) + field_fullname = helpers.get_class_fullname(field.__class__) + return self._get_field_instance(field_fullname) diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index f0c436ca0..10a53dc66 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -1,189 +1,137 @@ -from typing import Dict, List, Optional, Type, cast +from abc import abstractmethod +from typing import Dict, Optional, Type +from django.db import models from django.db.models.base import Model from django.db.models.fields import DateField, DateTimeField -from django.db.models.fields.related import ForeignKey -from django.db.models.fields.reverse_related import ( - ManyToManyRel, ManyToOneRel, OneToOneRel, +from django.db.models.fields.related import ForeignKey, OneToOneField +from mypy.nodes import ( + ARG_STAR2, MDEF, Argument, FuncDef, SymbolTableNode, TypeInfo, Var, ) -from mypy.nodes import ARG_STAR2, Argument, Context, FuncDef, TypeInfo, Var from mypy.plugin import ClassDefContext from mypy.plugins import common -from mypy.semanal import SemanticAnalyzer +from mypy.semanal import dummy_context from mypy.types import AnyType, Instance from mypy.types import Type as MypyType from mypy.types import TypeOfAny -from mypy_django_plugin.django.context import DjangoContext -from mypy_django_plugin.lib import fullnames, helpers -from mypy_django_plugin.transformers import fields -from mypy_django_plugin.transformers.fields import get_field_descriptor_types - - -class ModelClassInitializer: - api: SemanticAnalyzer - - def __init__(self, ctx: ClassDefContext, django_context: DjangoContext): - self.api = cast(SemanticAnalyzer, ctx.api) - self.model_classdef = ctx.cls - self.django_context = django_context - self.ctx = ctx - - def lookup_typeinfo(self, fullname: str) -> Optional[TypeInfo]: - return helpers.lookup_fully_qualified_typeinfo(self.api, fullname) - - def lookup_typeinfo_or_incomplete_defn_error(self, fullname: str) -> TypeInfo: - info = self.lookup_typeinfo(fullname) - if info is None: - raise helpers.IncompleteDefnException(f'No {fullname!r} found') - return info - - def lookup_class_typeinfo_or_incomplete_defn_error(self, klass: type) -> TypeInfo: - fullname = helpers.get_class_fullname(klass) - field_info = self.lookup_typeinfo_or_incomplete_defn_error(fullname) - return field_info - - def create_new_var(self, name: str, typ: MypyType) -> Var: - # type=: type of the variable itself - var = Var(name=name, type=typ) - # var.info: type of the object variable is bound to - var.info = self.model_classdef.info - var._fullname = self.model_classdef.info.fullname + '.' + name - var.is_initialized_in_class = True - var.is_inferred = True - return var - - def add_new_node_to_model_class(self, name: str, typ: MypyType) -> None: - helpers.add_new_sym_for_info(self.model_classdef.info, - name=name, - sym_type=typ) - - def add_new_class_for_current_module(self, name: str, bases: List[Instance]) -> TypeInfo: - current_module = self.api.modules[self.model_classdef.info.module_name] - new_class_info = helpers.add_new_class_for_module(current_module, - name=name, bases=bases) - return new_class_info - - def run(self) -> None: - model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.fullname) +from mypy_django_plugin.lib import chk_helpers, fullnames, helpers +from mypy_django_plugin.transformers import fields, new_helpers + + +class TransformModelClassCallback(helpers.ClassDefPluginCallback): + def get_real_manager_fullname(self, manager_fullname: str) -> str: + model_info = self.lookup_typeinfo_or_defer(fullnames.MODEL_CLASS_FULLNAME) + assert model_info is not None + real_manager_fullname = model_info.metadata.get('managers', {}).get(manager_fullname, manager_fullname) + return real_manager_fullname + + def modify_class_defn(self) -> None: + model_cls = self.django_context.get_model_class_by_fullname(self.class_defn.fullname) if model_cls is None: - return - self.run_with_model_cls(model_cls) + return None + return self.modify_model_class_defn(model_cls) - def run_with_model_cls(self, model_cls): - pass + def add_new_model_attribute(self, name: str, typ: MypyType, force_replace: bool = False) -> None: + model_info = self.class_defn.info + if name in model_info.names and not force_replace: + raise ValueError('Attribute already exists on the model') + var = Var(name, type=typ) + var.info = model_info + var._fullname = self.semanal_api.qualified_name(name) + var.is_initialized_in_class = True -class InjectAnyAsBaseForNestedMeta(ModelClassInitializer): - """ - Replaces - class MyModel(models.Model): - class Meta: - pass - with - class MyModel(models.Model): - class Meta(Any): - pass - to get around incompatible Meta inner classes for different models. - """ + sym = SymbolTableNode(MDEF, var, plugin_generated=True) + error_context = None if force_replace else dummy_context() + added = self.semanal_api.add_symbol_table_node(name, sym, context=error_context) + assert added - def run(self) -> None: - meta_node = helpers.get_nested_meta_node_for_current_class(self.model_classdef.info) - if meta_node is None: - return None - meta_node.fallback_to_any = True + def lookup_typeinfo_for_class_or_defer(self, klass: type, *, + reason_for_defer: Optional[str] = None) -> Optional[TypeInfo]: + manager_cls_fullname = helpers.get_class_fullname(klass) + return self.lookup_typeinfo_or_defer(manager_cls_fullname, + reason_for_defer=reason_for_defer) + @abstractmethod + def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None: + raise NotImplementedError -class AddDefaultPrimaryKey(ModelClassInitializer): - def run_with_model_cls(self, model_cls: Type[Model]) -> None: - auto_field = model_cls._meta.auto_field - if auto_field and not self.model_classdef.info.has_readable_member(auto_field.attname): - # autogenerated field - auto_field_fullname = helpers.get_class_fullname(auto_field.__class__) - auto_field_info = self.lookup_typeinfo_or_incomplete_defn_error(auto_field_fullname) - - set_type, get_type = fields.get_field_descriptor_types(auto_field_info, is_nullable=False) - self.add_new_node_to_model_class(auto_field.attname, Instance(auto_field_info, - [set_type, get_type])) - - -class AddRelatedModelsId(ModelClassInitializer): - def run_with_model_cls(self, model_cls: Type[Model]) -> None: - for field in model_cls._meta.get_fields(): - if isinstance(field, ForeignKey): - related_model_cls = self.django_context.get_field_related_model_cls(field) - if related_model_cls is None: - error_context: Context = self.ctx.cls - field_sym = self.ctx.cls.info.get(field.name) - if field_sym is not None and field_sym.node is not None: - error_context = field_sym.node - self.api.fail(f'Cannot find model {field.related_model!r} ' - f'referenced in field {field.name!r} ', - ctx=error_context) - self.add_new_node_to_model_class(field.attname, - AnyType(TypeOfAny.explicit)) - continue - if related_model_cls._meta.abstract: - continue +class AddDefaultManagerCallback(TransformModelClassCallback): + def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None: + if ('_default_manager' in self.class_defn.info.names + or runtime_model_cls._meta.default_manager is None): + return None - rel_primary_key_field = self.django_context.get_primary_key_field(related_model_cls) - try: - field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__) - except helpers.IncompleteDefnException as exc: - if not self.api.final_iteration: - raise exc - else: - continue + runtime_default_manager_class = runtime_model_cls._meta.default_manager.__class__ + runtime_manager_cls_fullname = new_helpers.get_class_fullname(runtime_default_manager_class) + manager_cls_fullname = self.get_real_manager_fullname(runtime_manager_cls_fullname) - is_nullable = self.django_context.get_field_nullability(field, None) - set_type, get_type = get_field_descriptor_types(field_info, is_nullable) - self.add_new_node_to_model_class(field.attname, - Instance(field_info, [set_type, get_type])) + try: + default_manager_info = self.lookup_typeinfo_or_defer(manager_cls_fullname) + except new_helpers.TypeInfoNotFound: + default_manager_info = None + + if default_manager_info is None: + if getattr(runtime_model_cls._meta.default_manager, '_built_with_as_manager', False): + # it's a Model.as_manager() class and will cause TypeNotFound exception without proper support + # fallback to Any for now to avoid false positives + self.add_new_model_attribute('_default_manager', AnyType(TypeOfAny.implementation_artifact)) + return + self.add_new_model_attribute('_default_manager', + Instance(default_manager_info, [Instance(self.class_defn.info, [])])) -class AddManagers(ModelClassInitializer): - def has_any_parametrized_manager_as_base(self, info: TypeInfo) -> bool: - for base in helpers.iter_bases(info): - if self.is_any_parametrized_manager(base): - return True - return False + +class AddManagersCallback(TransformModelClassCallback): def is_any_parametrized_manager(self, typ: Instance) -> bool: return typ.type.fullname in fullnames.MANAGER_CLASSES and isinstance(typ.args[0], AnyType) def get_generated_manager_mappings(self, base_manager_fullname: str) -> Dict[str, str]: - base_manager_info = self.lookup_typeinfo(base_manager_fullname) + try: + base_manager_info = self.lookup_typeinfo_or_defer(base_manager_fullname) + except new_helpers.TypeInfoNotFound: + base_manager_info = None + if (base_manager_info is None or 'from_queryset_managers' not in base_manager_info.metadata): return {} return base_manager_info.metadata['from_queryset_managers'] + def has_any_parametrized_manager_as_base(self, info: TypeInfo) -> bool: + for base in helpers.iter_bases(info): + if self.is_any_parametrized_manager(base): + return True + return False + def create_new_model_parametrized_manager(self, name: str, base_manager_info: TypeInfo) -> Instance: bases = [] for original_base in base_manager_info.bases: if self.is_any_parametrized_manager(original_base): if original_base.type is None: - raise helpers.IncompleteDefnException() - + raise new_helpers.TypeInfoNotFound() original_base = helpers.reparametrize_instance(original_base, - [Instance(self.model_classdef.info, [])]) + [Instance(self.class_defn.info, [])]) bases.append(original_base) - new_manager_info = self.add_new_class_for_current_module(name, bases) + current_module = self.semanal_api.modules[self.class_defn.info.module_name] + new_manager_info = chk_helpers.add_new_class_for_current_module(current_module, name, bases) + # copy fields to a new manager new_cls_def_context = ClassDefContext(cls=new_manager_info.defn, reason=self.ctx.reason, - api=self.api) - custom_manager_type = Instance(new_manager_info, [Instance(self.model_classdef.info, [])]) + api=self.semanal_api) + custom_manager_type = Instance(new_manager_info, [Instance(self.class_defn.info, [])]) for name, sym in base_manager_info.names.items(): # replace self type with new class, if copying method if isinstance(sym.node, FuncDef): - helpers.copy_method_to_another_class(new_cls_def_context, - self_type=custom_manager_type, - new_method_name=name, - method_node=sym.node) + self.copy_method_to_another_class(new_cls_def_context, + self_type=custom_manager_type, + new_method_name=name, + method_node=sym.node) continue new_sym = sym.copy() @@ -193,117 +141,163 @@ def create_new_model_parametrized_manager(self, name: str, base_manager_info: Ty new_var._fullname = new_manager_info.fullname + '.' + name new_sym.node = new_var new_manager_info.names[name] = new_sym - return custom_manager_type - def run_with_model_cls(self, model_cls: Type[Model]) -> None: - for manager_name, manager in model_cls._meta.managers_map.items(): + def modify_model_class_defn(self, runtime_model_cls: Type[models.Model]) -> None: + for manager_name, manager in runtime_model_cls._meta.managers_map.items(): manager_class_name = manager.__class__.__name__ - manager_fullname = helpers.get_class_fullname(manager.__class__) try: - manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname) - except helpers.IncompleteDefnException as exc: - if not self.api.final_iteration: - raise exc - else: - base_manager_fullname = helpers.get_class_fullname(manager.__class__.__bases__[0]) - generated_managers = self.get_generated_manager_mappings(base_manager_fullname) - if manager_fullname not in generated_managers: - # not a generated manager, continue with the loop - continue - real_manager_fullname = generated_managers[manager_fullname] - manager_info = self.lookup_typeinfo(real_manager_fullname) # type: ignore - if manager_info is None: - continue - manager_class_name = real_manager_fullname.rsplit('.', maxsplit=1)[1] - - if manager_name not in self.model_classdef.info.names: - manager_type = Instance(manager_info, [Instance(self.model_classdef.info, [])]) - self.add_new_node_to_model_class(manager_name, manager_type) + manager_info = self.lookup_typeinfo_for_class_or_defer(manager.__class__) + if manager_info is None: + continue + except new_helpers.TypeInfoNotFound: + manager_info = None + + # creating custom manager class only if there's none in lookup and it's final iteration + if manager_info is None: + manager_fullname = helpers.get_class_fullname(manager.__class__) + base_manager_fullname = helpers.get_class_fullname(manager.__class__.__bases__[0]) + generated_managers = self.get_generated_manager_mappings(base_manager_fullname) + + if manager_fullname not in generated_managers: + continue + + real_manager_fullname = generated_managers[manager_fullname] + try: + manager_info = self.lookup_typeinfo_or_defer(real_manager_fullname) + except new_helpers.TypeInfoNotFound: + manager_info = None + + if manager_info is None: + continue + + manager_class_name = real_manager_fullname.rsplit('.', maxsplit=1)[1] + + if manager_name not in self.class_defn.info.names: + manager_type = Instance(manager_info, [Instance(self.class_defn.info, [])]) + chk_helpers.add_new_sym_for_info(self.class_defn.info, + name=manager_name, + sym_type=manager_type) else: - # creates new MODELNAME_MANAGERCLASSNAME class that represents manager parametrized with current model if not self.has_any_parametrized_manager_as_base(manager_info): continue - custom_model_manager_name = manager.model.__name__ + '_' + manager_class_name try: custom_manager_type = self.create_new_model_parametrized_manager(custom_model_manager_name, base_manager_info=manager_info) - except helpers.IncompleteDefnException: + except new_helpers.TypeInfoNotFound: continue - self.add_new_node_to_model_class(manager_name, custom_manager_type) + chk_helpers.add_new_sym_for_info(self.class_defn.info, name=manager_name, sym_type=custom_manager_type) -class AddDefaultManagerAttribute(ModelClassInitializer): - def run_with_model_cls(self, model_cls: Type[Model]) -> None: - # add _default_manager - if '_default_manager' not in self.model_classdef.info.names: - default_manager_fullname = helpers.get_class_fullname(model_cls._meta.default_manager.__class__) - default_manager_info = self.lookup_typeinfo_or_incomplete_defn_error(default_manager_fullname) - default_manager = Instance(default_manager_info, [Instance(self.model_classdef.info, [])]) - self.add_new_node_to_model_class('_default_manager', default_manager) +class AddPrimaryKeyIfDoesNotExist(TransformModelClassCallback): + """ + Adds default primary key to models which does not define their own. + class User(models.Model): + name = models.TextField() + """ + + def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None: + auto_pk_field = runtime_model_cls._meta.auto_field + if auto_pk_field is None: + # defined explicitly + return None + auto_pk_field_name = auto_pk_field.attname + if auto_pk_field_name in self.class_defn.info.names: + # added on previous iteration + return None + + auto_pk_field_info = self.lookup_typeinfo_for_class_or_defer(auto_pk_field.__class__) + if auto_pk_field_info is None: + return None + self.add_new_model_attribute(auto_pk_field_name, + fields.get_field_type(auto_pk_field_info, is_nullable=False)) -class AddRelatedManagers(ModelClassInitializer): - def run_with_model_cls(self, model_cls: Type[Model]) -> None: - # add related managers - for relation in self.django_context.get_model_relations(model_cls): - attname = relation.get_accessor_name() - if attname is None: - # no reverse accessor + +class AddRelatedManagersCallback(TransformModelClassCallback): + def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None: + for reverse_manager_name, relation in self.django_context.get_model_relations(runtime_model_cls): + if (reverse_manager_name is None + or reverse_manager_name in self.class_defn.info.names): continue - related_model_cls = self.django_context.get_field_related_model_cls(relation) - if related_model_cls is None: + self.add_new_model_attribute(reverse_manager_name, AnyType(TypeOfAny.implementation_artifact)) + + +class AddForeignPrimaryKeys(TransformModelClassCallback): + def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None: + for field in runtime_model_cls._meta.get_fields(): + if not isinstance(field, (OneToOneField, ForeignKey)): + continue + rel_pk_field_name = field.attname + if rel_pk_field_name in self.class_defn.info.names: continue - try: - related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(related_model_cls) - except helpers.IncompleteDefnException as exc: - if not self.api.final_iteration: - raise exc + related_model_cls = self.django_context.get_field_related_model_cls(field) + if related_model_cls is None: + field_sym = self.class_defn.info.get(field.name) + if field_sym is not None and field_sym.node is not None: + error_context = field_sym.node else: - continue + error_context = self.class_defn # type: ignore + self.semanal_api.fail(f'Cannot find model {field.related_model!r} ' + f'referenced in field {field.name!r} ', + ctx=error_context) + self.add_new_model_attribute(rel_pk_field_name, AnyType(TypeOfAny.from_error)) + continue + if related_model_cls._meta.abstract: + continue - if isinstance(relation, OneToOneRel): - self.add_new_node_to_model_class(attname, Instance(related_model_info, [])) + rel_pk_field = self.django_context.get_primary_key_field(related_model_cls) + rel_pk_field_info = self.lookup_typeinfo_for_class_or_defer(rel_pk_field.__class__) + if rel_pk_field_info is None: continue - if isinstance(relation, (ManyToOneRel, ManyToManyRel)): - try: - related_manager_info = self.lookup_typeinfo_or_incomplete_defn_error(fullnames.RELATED_MANAGER_CLASS) # noqa: E501 - if 'objects' not in related_model_info.names: - raise helpers.IncompleteDefnException() - except helpers.IncompleteDefnException as exc: - if not self.api.final_iteration: - raise exc - else: - continue - - # create new RelatedManager subclass - parametrized_related_manager_type = Instance(related_manager_info, - [Instance(related_model_info, [])]) - default_manager_type = related_model_info.names['objects'].type - if (default_manager_type is None - or not isinstance(default_manager_type, Instance) - or default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME): - self.add_new_node_to_model_class(attname, parametrized_related_manager_type) - continue + field_type = fields.get_field_type(rel_pk_field_info, + is_nullable=self.django_context.get_field_nullability(field)) + self.add_new_model_attribute(rel_pk_field_name, field_type) - name = related_model_cls.__name__ + '_' + 'RelatedManager' - bases = [parametrized_related_manager_type, default_manager_type] - new_related_manager_info = self.add_new_class_for_current_module(name, bases) - self.add_new_node_to_model_class(attname, Instance(new_related_manager_info, [])) +class InjectAnyAsBaseForNestedMeta(TransformModelClassCallback): + """ + Replaces + class MyModel(models.Model): + class Meta: + pass + with + class MyModel(models.Model): + class Meta(Any): + pass + to get around incompatible Meta inner classes for different models. + """ + + def modify_class_defn(self) -> None: + meta_node = helpers.get_nested_meta_node_for_current_class(self.class_defn.info) + if meta_node is None: + return None + meta_node.fallback_to_any = True + + +class AddMetaOptionsAttribute(TransformModelClassCallback): + def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None: + if '_meta' not in self.class_defn.info.names: + options_info = self.lookup_typeinfo_or_defer(fullnames.OPTIONS_CLASS_FULLNAME) + if options_info is not None: + self.add_new_model_attribute('_meta', + Instance(options_info, [ + Instance(self.class_defn.info, []) + ])) -class AddExtraFieldMethods(ModelClassInitializer): - def run_with_model_cls(self, model_cls: Type[Model]) -> None: +class AddExtraFieldMethods(TransformModelClassCallback): + def modify_model_class_defn(self, runtime_model_cls: Type[Model]) -> None: # get_FOO_display for choices - for field in self.django_context.get_model_fields(model_cls): + for field in self.django_context.get_model_fields(runtime_model_cls): if field.choices: - info = self.lookup_typeinfo_or_incomplete_defn_error('builtins.str') + info = self.lookup_typeinfo_or_defer('builtins.str') + assert info is not None return_type = Instance(info, []) common.add_method(self.ctx, name='get_{}_display'.format(field.attname), @@ -311,9 +305,9 @@ def run_with_model_cls(self, model_cls: Type[Model]) -> None: return_type=return_type) # get_next_by, get_previous_by for Date, DateTime - for field in self.django_context.get_model_fields(model_cls): + for field in self.django_context.get_model_fields(runtime_model_cls): if isinstance(field, (DateField, DateTimeField)) and not field.null: - return_type = Instance(self.model_classdef.info, []) + return_type = Instance(self.class_defn.info, []) common.add_method(self.ctx, name='get_next_by_{}'.format(field.attname), args=[Argument(Var('kwargs', AnyType(TypeOfAny.explicit)), @@ -330,31 +324,21 @@ def run_with_model_cls(self, model_cls: Type[Model]) -> None: return_type=return_type) -class AddMetaOptionsAttribute(ModelClassInitializer): - def run_with_model_cls(self, model_cls: Type[Model]) -> None: - if '_meta' not in self.model_classdef.info.names: - options_info = self.lookup_typeinfo_or_incomplete_defn_error(fullnames.OPTIONS_CLASS_FULLNAME) - self.add_new_node_to_model_class('_meta', - Instance(options_info, [ - Instance(self.model_classdef.info, []) - ])) - - -def process_model_class(ctx: ClassDefContext, - django_context: DjangoContext) -> None: - initializers = [ - InjectAnyAsBaseForNestedMeta, - AddDefaultPrimaryKey, - AddRelatedModelsId, - AddManagers, - AddDefaultManagerAttribute, - AddRelatedManagers, - AddExtraFieldMethods, - AddMetaOptionsAttribute, - ] - for initializer_cls in initializers: - try: - initializer_cls(ctx, django_context).run() - except helpers.IncompleteDefnException: - if not ctx.api.final_iteration: - ctx.api.defer() +class ModelCallback(helpers.ClassDefPluginCallback): + def __call__(self, ctx: ClassDefContext) -> None: + callback_classes = [ + AddManagersCallback, + AddPrimaryKeyIfDoesNotExist, + AddForeignPrimaryKeys, + AddDefaultManagerCallback, + AddRelatedManagersCallback, + InjectAnyAsBaseForNestedMeta, + AddMetaOptionsAttribute, + AddExtraFieldMethods, + ] + for callback_cls in callback_classes: + callback = callback_cls(self.plugin) # type: ignore + callback.__call__(ctx) + + def modify_class_defn(self) -> None: + raise NotImplementedError() diff --git a/mypy_django_plugin/transformers/new_helpers.py b/mypy_django_plugin/transformers/new_helpers.py new file mode 100644 index 000000000..0924a7b8e --- /dev/null +++ b/mypy_django_plugin/transformers/new_helpers.py @@ -0,0 +1,30 @@ +from typing import Union + +from mypy.nodes import MypyFile, TypeInfo + + +class IncompleteDefnError(Exception): + pass + + +class TypeInfoNotFound(IncompleteDefnError): + def __init__(self, fullname: str) -> None: + super().__init__(f'It is final iteration and required type {fullname!r} is not ready yet.') + + +class AttributeNotFound(IncompleteDefnError): + def __init__(self, node: Union[TypeInfo, MypyFile], attrname: str) -> None: + super().__init__(f'Attribute {attrname!r} is not defined for the {node.fullname!r}.') + + +class NameNotFound(IncompleteDefnError): + def __init__(self, name: str) -> None: + super().__init__(f'Could not find {name!r} in the current activated namespaces') + + +class SymbolAdditionNotPossible(Exception): + pass + + +def get_class_fullname(klass: type) -> str: + return klass.__module__ + '.' + klass.__qualname__ diff --git a/mypy_django_plugin/transformers/orm_lookups.py b/mypy_django_plugin/transformers/orm_lookups.py index 0aa516be0..5eb0e96a5 100644 --- a/mypy_django_plugin/transformers/orm_lookups.py +++ b/mypy_django_plugin/transformers/orm_lookups.py @@ -1,51 +1,47 @@ -from mypy.plugin import MethodContext from mypy.types import AnyType, Instance from mypy.types import Type as MypyType from mypy.types import TypeOfAny -from mypy_django_plugin.django.context import DjangoContext -from mypy_django_plugin.lib import fullnames, helpers +from mypy_django_plugin.lib import chk_helpers, fullnames, helpers -def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext) -> MypyType: - lookup_kwargs = ctx.arg_names[1] - provided_lookup_types = ctx.arg_types[1] +class QuerySetFilterTypecheckCallback(helpers.GetMethodCallback): + def resolve_combinable_type(self, combinable_type: Instance) -> MypyType: + if combinable_type.type.fullname != fullnames.F_EXPRESSION_FULLNAME: + # Combinables aside from F expressions are unsupported + return AnyType(TypeOfAny.explicit) - assert isinstance(ctx.type, Instance) + return self.django_context.resolve_f_expression_type(combinable_type) - if not ctx.type.args or not isinstance(ctx.type.args[0], Instance): - return ctx.default_return_type + def get_method_return_type(self) -> MypyType: + lookup_kwargs = self.ctx.arg_names[1] + provided_lookup_types = self.ctx.arg_types[1] - model_cls_fullname = ctx.type.args[0].type.fullname - model_cls = django_context.get_model_class_by_fullname(model_cls_fullname) - if model_cls is None: - return ctx.default_return_type + if not self.callee_type.args or not isinstance(self.callee_type.args[0], Instance): + return self.default_return_type - for lookup_kwarg, provided_type in zip(lookup_kwargs, provided_lookup_types): - if lookup_kwarg is None: - continue - if (isinstance(provided_type, Instance) - and provided_type.type.has_base('django.db.models.expressions.Combinable')): - provided_type = resolve_combinable_type(provided_type, django_context) + model_cls_fullname = self.callee_type.args[0].type.fullname + model_cls = self.django_context.get_model_class_by_fullname(model_cls_fullname) + if model_cls is None: + return self.default_return_type - lookup_type = django_context.resolve_lookup_expected_type(ctx, model_cls, lookup_kwarg) - # Managers as provided_type is not supported yet - if (isinstance(provided_type, Instance) - and helpers.has_any_of_bases(provided_type.type, (fullnames.MANAGER_CLASS_FULLNAME, - fullnames.QUERYSET_CLASS_FULLNAME))): - return ctx.default_return_type + for lookup_kwarg, provided_type in zip(lookup_kwargs, provided_lookup_types): + if lookup_kwarg is None: + continue + if (isinstance(provided_type, Instance) + and provided_type.type.has_base('django.db.models.expressions.Combinable')): + provided_type = self.resolve_combinable_type(provided_type) - helpers.check_types_compatible(ctx, - expected_type=lookup_type, - actual_type=provided_type, - error_message=f'Incompatible type for lookup {lookup_kwarg!r}:') + lookup_type = self.django_context.resolve_lookup_expected_type(self.ctx, model_cls, lookup_kwarg) + # Managers as provided_type is not supported yet + if (isinstance(provided_type, Instance) + and helpers.has_any_of_bases(provided_type.type, (fullnames.MANAGER_CLASS_FULLNAME, + fullnames.QUERYSET_CLASS_FULLNAME))): + return self.default_return_type - return ctx.default_return_type + chk_helpers.check_types_compatible(self.ctx, + expected_type=lookup_type, + actual_type=provided_type, + error_message=f'Incompatible type for lookup {lookup_kwarg!r}:') - -def resolve_combinable_type(combinable_type: Instance, django_context: DjangoContext) -> MypyType: - if combinable_type.type.fullname != fullnames.F_EXPRESSION_FULLNAME: - # Combinables aside from F expressions are unsupported - return AnyType(TypeOfAny.explicit) - - return django_context.resolve_f_expression_type(combinable_type) + return self.default_return_type diff --git a/mypy_django_plugin/transformers/querysets.py b/mypy_django_plugin/transformers/querysets.py index c157bb4a0..bb7f585cf 100644 --- a/mypy_django_plugin/transformers/querysets.py +++ b/mypy_django_plugin/transformers/querysets.py @@ -1,4 +1,4 @@ -from collections import OrderedDict +import collections from typing import List, Optional, Sequence, Type from django.core.exceptions import FieldError @@ -6,7 +6,7 @@ from django.db.models.fields.related import RelatedField from django.db.models.fields.reverse_related import ForeignObjectRel from mypy.nodes import Expression, NameExpr -from mypy.plugin import FunctionContext, MethodContext +from mypy.plugin import MethodContext from mypy.types import AnyType, Instance from mypy.types import Type as MypyType from mypy.types import TypeOfAny @@ -14,28 +14,7 @@ from mypy_django_plugin.django.context import ( DjangoContext, LookupsAreUnsupported, ) -from mypy_django_plugin.lib import fullnames, helpers - - -def _extract_model_type_from_queryset(queryset_type: Instance) -> Optional[Instance]: - for base_type in [queryset_type, *queryset_type.type.bases]: - if (len(base_type.args) - and isinstance(base_type.args[0], Instance) - and base_type.args[0].type.has_base(fullnames.MODEL_CLASS_FULLNAME)): - return base_type.args[0] - return None - - -def determine_proper_manager_type(ctx: FunctionContext) -> MypyType: - default_return_type = ctx.default_return_type - assert isinstance(default_return_type, Instance) - - outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class() - if (outer_model_info is None - or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME)): - return default_return_type - - return helpers.reparametrize_instance(default_return_type, [Instance(outer_model_info, [])]) +from mypy_django_plugin.lib import chk_helpers, fullnames, helpers def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model], @@ -55,18 +34,28 @@ def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext return AnyType(TypeOfAny.from_error) lookup_field = django_context.get_primary_key_field(related_model_cls) - field_get_type = django_context.get_field_get_type(helpers.get_typechecker_api(ctx), + field_get_type = django_context.get_field_get_type(chk_helpers.get_typechecker_api(ctx), lookup_field, method=method) return field_get_type +def resolve_field_lookups(lookup_exprs: Sequence[Expression], django_context: DjangoContext) -> Optional[List[str]]: + field_lookups = [] + for field_lookup_expr in lookup_exprs: + field_lookup = helpers.resolve_string_attribute_value(field_lookup_expr, django_context) + if field_lookup is None: + return None + field_lookups.append(field_lookup) + return field_lookups + + def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model], flat: bool, named: bool) -> MypyType: field_lookups = resolve_field_lookups(ctx.args[0], django_context) if field_lookups is None: return AnyType(TypeOfAny.from_error) - typechecker_api = helpers.get_typechecker_api(ctx) + typechecker_api = chk_helpers.get_typechecker_api(ctx) if len(field_lookups) == 0: if flat: primary_key_field = django_context.get_primary_key_field(model_cls) @@ -75,12 +64,12 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, assert lookup_type is not None return lookup_type elif named: - column_types: 'OrderedDict[str, MypyType]' = OrderedDict() + column_types = collections.OrderedDict() for field in django_context.get_model_fields(model_cls): column_type = django_context.get_field_get_type(typechecker_api, field, method='values_list') column_types[field.attname] = column_type - return helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types) + return chk_helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types) else: # flat=False, named=False, all fields field_lookups = [] @@ -91,7 +80,7 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, typechecker_api.fail("'flat' is not valid when 'values_list' is called with more than one field", ctx.context) return AnyType(TypeOfAny.from_error) - column_types = OrderedDict() + column_types = collections.OrderedDict() for field_lookup in field_lookups: lookup_field_type = get_field_type_from_lookup(ctx, django_context, model_cls, lookup=field_lookup, method='values_list') @@ -103,90 +92,90 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, assert len(column_types) == 1 row_type = next(iter(column_types.values())) elif named: - row_type = helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types) + row_type = chk_helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types) else: - row_type = helpers.make_tuple(typechecker_api, list(column_types.values())) + row_type = chk_helpers.make_tuple(typechecker_api, list(column_types.values())) return row_type -def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context: DjangoContext) -> MypyType: - # called on the Instance, returns QuerySet of something - assert isinstance(ctx.type, Instance) - assert isinstance(ctx.default_return_type, Instance) +class QuerySetMethodCallback(helpers.GetMethodCallback): + def current_model_type(self) -> Optional[Instance]: + for base_type in [self.callee_type, *self.callee_type.type.bases]: + if (len(base_type.args) + and isinstance(base_type.args[0], Instance) + and base_type.args[0].type.has_base(fullnames.MODEL_CLASS_FULLNAME)): + return base_type.args[0] + return None - model_type = _extract_model_type_from_queryset(ctx.type) - if model_type is None: - return AnyType(TypeOfAny.from_omitted_generics) - model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname) - if model_cls is None: - return ctx.default_return_type +class QuerySetValuesCallback(QuerySetMethodCallback): + def get_method_return_type(self) -> MypyType: + assert isinstance(self.default_return_type, Instance) - flat_expr = helpers.get_call_argument_by_name(ctx, 'flat') - if flat_expr is not None and isinstance(flat_expr, NameExpr): - flat = helpers.parse_bool(flat_expr) - else: - flat = False + model_type = self.current_model_type() + if model_type is None: + return AnyType(TypeOfAny.from_omitted_generics) - named_expr = helpers.get_call_argument_by_name(ctx, 'named') - if named_expr is not None and isinstance(named_expr, NameExpr): - named = helpers.parse_bool(named_expr) - else: - named = False + model_cls = self.django_context.get_model_class_by_fullname(model_type.type.fullname) + if model_cls is None: + return self.default_return_type - if flat and named: - ctx.api.fail("'flat' and 'named' can't be used together", ctx.context) - return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)]) + field_lookups = resolve_field_lookups(self.ctx.args[0], self.django_context) + if field_lookups is None: + return AnyType(TypeOfAny.from_error) - # account for possible None - flat = flat or False - named = named or False + if len(field_lookups) == 0: + for field in self.django_context.get_model_fields(model_cls): + field_lookups.append(field.attname) - row_type = get_values_list_row_type(ctx, django_context, model_cls, - flat=flat, named=named) - return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type]) + column_types = collections.OrderedDict() + for field_lookup in field_lookups: + field_lookup_type = get_field_type_from_lookup(self.ctx, self.django_context, model_cls, + lookup=field_lookup, method='values') + if field_lookup_type is None: + return helpers.reparametrize_instance(self.default_return_type, + [model_type, AnyType(TypeOfAny.from_error)]) + column_types[field_lookup] = field_lookup_type -def resolve_field_lookups(lookup_exprs: Sequence[Expression], django_context: DjangoContext) -> Optional[List[str]]: - field_lookups = [] - for field_lookup_expr in lookup_exprs: - field_lookup = helpers.resolve_string_attribute_value(field_lookup_expr, django_context) - if field_lookup is None: - return None - field_lookups.append(field_lookup) - return field_lookups + row_type = chk_helpers.make_oneoff_typeddict(self.ctx.api, column_types, set(column_types.keys())) + return helpers.reparametrize_instance(self.default_return_type, [model_type, row_type]) -def extract_proper_type_queryset_values(ctx: MethodContext, django_context: DjangoContext) -> MypyType: - # called on QuerySet, return QuerySet of something - assert isinstance(ctx.type, Instance) - assert isinstance(ctx.default_return_type, Instance) +class QuerySetValuesListCallback(QuerySetMethodCallback): + def get_method_return_type(self) -> MypyType: + # called on the Instance, returns QuerySet of something + assert isinstance(self.default_return_type, Instance) - model_type = _extract_model_type_from_queryset(ctx.type) - if model_type is None: - return AnyType(TypeOfAny.from_omitted_generics) + model_type = self.current_model_type() + if model_type is None: + return AnyType(TypeOfAny.from_omitted_generics) - model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname) - if model_cls is None: - return ctx.default_return_type + model_cls = self.django_context.get_model_class_by_fullname(model_type.type.fullname) + if model_cls is None: + return self.default_return_type - field_lookups = resolve_field_lookups(ctx.args[0], django_context) - if field_lookups is None: - return AnyType(TypeOfAny.from_error) + flat_expr = chk_helpers.get_call_argument_by_name(self.ctx, 'flat') + if flat_expr is not None and isinstance(flat_expr, NameExpr): + flat = helpers.parse_bool(flat_expr) + else: + flat = False - if len(field_lookups) == 0: - for field in django_context.get_model_fields(model_cls): - field_lookups.append(field.attname) + named_expr = chk_helpers.get_call_argument_by_name(self.ctx, 'named') + if named_expr is not None and isinstance(named_expr, NameExpr): + named = helpers.parse_bool(named_expr) + else: + named = False - column_types: 'OrderedDict[str, MypyType]' = OrderedDict() - for field_lookup in field_lookups: - field_lookup_type = get_field_type_from_lookup(ctx, django_context, model_cls, - lookup=field_lookup, method='values') - if field_lookup_type is None: - return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)]) + if flat and named: + self.ctx.api.fail("'flat' and 'named' can't be used together", self.ctx.context) + return helpers.reparametrize_instance(self.default_return_type, [model_type, AnyType(TypeOfAny.from_error)]) - column_types[field_lookup] = field_lookup_type + # account for possible None + flat = flat or False + named = named or False - row_type = helpers.make_typeddict(ctx.api, column_types, set(column_types.keys())) - return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type]) + row_type = get_values_list_row_type(self.ctx, self.django_context, model_cls, + flat=flat, named=named) + return helpers.reparametrize_instance(self.default_return_type, [model_type, row_type]) diff --git a/mypy_django_plugin/transformers/related_managers.py b/mypy_django_plugin/transformers/related_managers.py new file mode 100644 index 000000000..46badcbdc --- /dev/null +++ b/mypy_django_plugin/transformers/related_managers.py @@ -0,0 +1,72 @@ +from django.db.models.fields.reverse_related import ( + ForeignObjectRel, ManyToManyRel, ManyToOneRel, OneToOneRel, +) +from mypy.checker import gen_unique_name +from mypy.types import Instance +from mypy.types import Type as MypyType + +from mypy_django_plugin.lib import fullnames, helpers + + +class GetRelatedManagerCallback(helpers.GetAttributeCallback): + obj_type: Instance + + def get_related_manager_type(self, relation: ForeignObjectRel) -> MypyType: + related_model_cls = self.django_context.get_field_related_model_cls(relation) + if related_model_cls is None: + # could not find a referenced model (maybe invalid to= value, or GenericForeignKey) + # TODO: show error + return self.default_attr_type + + related_model_info = self.lookup_typeinfo(helpers.get_class_fullname(related_model_cls)) + if related_model_info is None: + # TODO: show error + return self.default_attr_type + + if isinstance(relation, OneToOneRel): + return Instance(related_model_info, []) + + elif isinstance(relation, (ManyToOneRel, ManyToManyRel)): + related_manager_info = self.lookup_typeinfo(fullnames.RELATED_MANAGER_CLASS) + if related_manager_info is None: + return self.default_attr_type + + # get type of default_manager for model + default_manager_fullname = helpers.get_class_fullname(related_model_cls._meta.default_manager.__class__) + default_manager_info = self.lookup_typeinfo(default_manager_fullname) + if default_manager_info is None: + return self.default_attr_type + + default_manager_type = Instance(default_manager_info, [Instance(related_model_info, [])]) + related_manager_type = Instance(related_manager_info, + [Instance(related_model_info, [])]) + + if (not isinstance(default_manager_type, Instance) + or default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME): + # if not defined or trivial -> just return RelatedManager[Model] + return related_manager_type + + # make anonymous class + name = gen_unique_name(related_model_cls.__name__ + '_' + 'RelatedManager', + self.obj_type.type.names) + bases = [related_manager_type, default_manager_type] + new_manager_info = self.new_typeinfo(name, bases) + return Instance(new_manager_info, []) + + return self.default_attr_type + + def get_attribute_type(self) -> MypyType: + if not isinstance(self.obj_type, Instance): + # it's probably a UnionType, do nothing for now + return self.default_attr_type + + model_fullname = self.obj_type.type.fullname + model_cls = self.django_context.get_model_class_by_fullname(model_fullname) + if model_cls is None: + return self.default_attr_type + + for reverse_manager_name, relation in self.django_context.get_model_relations(model_cls): + if reverse_manager_name == self.name: + return self.get_related_manager_type(relation) + + return self.default_attr_type diff --git a/mypy_django_plugin/transformers/request.py b/mypy_django_plugin/transformers/request.py index 83899ce29..667d35d5e 100644 --- a/mypy_django_plugin/transformers/request.py +++ b/mypy_django_plugin/transformers/request.py @@ -1,34 +1,25 @@ -from mypy.plugin import AttributeContext from mypy.types import Instance from mypy.types import Type as MypyType from mypy.types import UnionType -from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.lib import helpers -def set_auth_user_model_as_type_for_request_user(ctx: AttributeContext, django_context: DjangoContext) -> MypyType: - # Imported here because django isn't properly loaded yet when module is loaded - from django.contrib.auth.base_user import AbstractBaseUser - from django.contrib.auth.models import AnonymousUser +class RequestUserModelCallback(helpers.GetAttributeCallback): + def get_attribute_type(self) -> MypyType: + auth_user_model = self.django_context.settings.AUTH_USER_MODEL + user_cls = self.django_context.apps_registry.get_model(auth_user_model) + user_info = helpers.lookup_class_typeinfo(self.type_checker, user_cls) - abstract_base_user_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), AbstractBaseUser) - anonymous_user_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), AnonymousUser) + if user_info is None: + return self.default_attr_type - # This shouldn't be able to happen, as we managed to import the models above. - assert abstract_base_user_info is not None - assert anonymous_user_info is not None + # Imported here because django isn't properly loaded yet when module is loaded + from django.contrib.auth.models import AnonymousUser - if ctx.default_attr_type != UnionType([Instance(abstract_base_user_info, []), Instance(anonymous_user_info, [])]): - # Type has been changed from the default in django-stubs. - # I.e. HttpRequest has been subclassed and user-type overridden, so let's leave it as is. - return ctx.default_attr_type + anonymous_user_info = helpers.lookup_class_typeinfo(self.type_checker, AnonymousUser) + if anonymous_user_info is None: + # This shouldn't be able to happen, as we managed to import the model above... + return Instance(user_info, []) - auth_user_model = django_context.settings.AUTH_USER_MODEL - user_cls = django_context.apps_registry.get_model(auth_user_model) - user_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), user_cls) - - if user_info is None: - return ctx.default_attr_type - - return UnionType([Instance(user_info, []), Instance(anonymous_user_info, [])]) + return UnionType([Instance(user_info, []), Instance(anonymous_user_info, [])]) diff --git a/mypy_django_plugin/transformers/settings.py b/mypy_django_plugin/transformers/settings.py index ba6490b4e..fd7997ce0 100644 --- a/mypy_django_plugin/transformers/settings.py +++ b/mypy_django_plugin/transformers/settings.py @@ -1,50 +1,45 @@ -from mypy.nodes import MemberExpr -from mypy.plugin import AttributeContext, FunctionContext from mypy.types import AnyType, Instance from mypy.types import Type as MypyType from mypy.types import TypeOfAny, TypeType -from mypy_django_plugin.django.context import DjangoContext from mypy_django_plugin.lib import helpers +from mypy_django_plugin.transformers import new_helpers -def get_user_model_hook(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: - auth_user_model = django_context.settings.AUTH_USER_MODEL - model_cls = django_context.apps_registry.get_model(auth_user_model) - model_cls_fullname = helpers.get_class_fullname(model_cls) +class GetUserModelCallback(helpers.GetFunctionCallback): + def get_function_return_type(self) -> MypyType: + auth_user_model = self.django_context.settings.AUTH_USER_MODEL + model_cls = self.django_context.apps_registry.get_model(auth_user_model) + model_cls_fullname = new_helpers.get_class_fullname(model_cls) - model_info = helpers.lookup_fully_qualified_typeinfo(helpers.get_typechecker_api(ctx), - model_cls_fullname) - if model_info is None: - return AnyType(TypeOfAny.unannotated) + model_info = helpers.lookup_fully_qualified_typeinfo(self.type_checker, model_cls_fullname) + if model_info is None: + return AnyType(TypeOfAny.unannotated) - return TypeType(Instance(model_info, [])) + return TypeType(Instance(model_info, [])) -def get_type_of_settings_attribute(ctx: AttributeContext, django_context: DjangoContext) -> MypyType: - assert isinstance(ctx.context, MemberExpr) - setting_name = ctx.context.name - if not hasattr(django_context.settings, setting_name): - ctx.api.fail(f"'Settings' object has no attribute {setting_name!r}", ctx.context) - return ctx.default_attr_type +class GetTypeOfSettingsAttributeCallback(helpers.GetAttributeCallback): + def get_attribute_type(self) -> MypyType: + if not hasattr(self.django_context.settings, self.name): + self.type_checker.fail(f"'Settings' object has no attribute {self.name!r}", self.ctx.context) + return self.default_attr_type - typechecker_api = helpers.get_typechecker_api(ctx) + # first look for the setting in the project settings file, then global settings + settings_module = self.type_checker.modules.get(self.django_context.django_settings_module) + global_settings_module = self.type_checker.modules.get('django.conf.global_settings') + for module in [settings_module, global_settings_module]: + if module is not None: + sym = module.names.get(self.name) + if sym is not None and sym.type is not None: + return sym.type - # first look for the setting in the project settings file, then global settings - settings_module = typechecker_api.modules.get(django_context.django_settings_module) - global_settings_module = typechecker_api.modules.get('django.conf.global_settings') - for module in [settings_module, global_settings_module]: - if module is not None: - sym = module.names.get(setting_name) - if sym is not None and sym.type is not None: - return sym.type + # if by any reason it isn't present there, get type from django settings + value = getattr(self.django_context.settings, self.name) + value_fullname = helpers.get_class_fullname(value.__class__) - # if by any reason it isn't present there, get type from django settings - value = getattr(django_context.settings, setting_name) - value_fullname = helpers.get_class_fullname(value.__class__) + value_info = helpers.lookup_fully_qualified_typeinfo(self.type_checker, value_fullname) + if value_info is None: + return self.default_attr_type - value_info = helpers.lookup_fully_qualified_typeinfo(typechecker_api, value_fullname) - if value_info is None: - return ctx.default_attr_type - - return Instance(value_info, []) + return Instance(value_info, []) diff --git a/scripts/enabled_test_modules.py b/scripts/enabled_test_modules.py index 49e519e66..85506052d 100644 --- a/scripts/enabled_test_modules.py +++ b/scripts/enabled_test_modules.py @@ -194,6 +194,8 @@ 'fixtures': [ 'Incompatible types in assignment (expression has type "int", target has type "Iterable[str]")', 'Incompatible types in assignment (expression has type "SpyManager[Spy]"', + 'Incompatible types in assignment (expression has type "SpyManager", base class "Person" defined the type as ' + '"Person_PersonManager3[Person]")', ], 'fixtures_regress': [ 'Unsupported left operand type for + ("None")', @@ -325,6 +327,9 @@ 'model_enums': [ "'bool' is not a valid base class", ], + 'multiple_database': [ + 'Unexpected attribute "extra_arg" for model "Book"', + ], 'null_queries': [ "Cannot resolve keyword 'foo' into field" ], @@ -397,6 +402,8 @@ 'sites_framework': [ 'expression has type "CurrentSiteManager[CustomArticle]", base class "AbstractArticle"', "Name 'Optional' is not defined", + 'Incompatible types in assignment (expression has type "CurrentSiteManager", base class "AbstractArticle" ' + 'defined the type as "AbstractArticle_CurrentSiteManager[AbstractArticle]"', ], 'sites_tests': [ '"RequestSite" of "Union[Site, RequestSite]" has no attribute "id"', @@ -489,6 +496,10 @@ 'wsgi': [ '"HttpResponse" has no attribute "block_size"', ], + 'test_abstract_inheritance': [ + 'Definition of "Meta" in base class "DescendantOne" is incompatible with definition in base class ' + '"DescendantTwo"', + ], } diff --git a/scripts/typecheck_tests.py b/scripts/typecheck_tests.py index 33fade38c..5b8195f75 100644 --- a/scripts/typecheck_tests.py +++ b/scripts/typecheck_tests.py @@ -150,10 +150,10 @@ def get_django_repo_object(branch: str) -> Repo: global_rc = 1 print(line) - unused_ignores = get_unused_ignores(ignored_message_freqs) - if unused_ignores: - print('UNUSED IGNORES ------------------------------------------------') - print('\n'.join(unused_ignores)) + # unused_ignores = get_unused_ignores(ignored_message_freqs) + # if unused_ignores: + # print('UNUSED IGNORES ------------------------------------------------') + # print('\n'.join(unused_ignores)) sys.exit(global_rc) diff --git a/setup.py b/setup.py index c41d1fc24..5b5687779 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ def find_stub_files(name: str) -> List[str]: dependencies = [ 'mypy>=0.782,<0.790', 'typing-extensions', - 'django', + 'django==3.1b1', ] setup( diff --git a/test-data/typecheck/fields/test_related.yml b/test-data/typecheck/fields/test_related.yml index 0a79d19b5..63562dcc4 100644 --- a/test-data/typecheck/fields/test_related.yml +++ b/test-data/typecheck/fields/test_related.yml @@ -36,7 +36,7 @@ pass class Book(models.Model): publisher = models.ForeignKey(to=Publisher, on_delete=models.CASCADE) - owner = models.ForeignKey(db_column='model_id', to='auth.User', on_delete=models.CASCADE) + owner = models.ForeignKey(to='auth.User', on_delete=models.CASCADE) - case: foreign_key_field_different_order_of_params main: | @@ -653,7 +653,7 @@ - case: related_manager_is_a_subclass_of_default_manager main: | from myapp.models import User - reveal_type(User().orders) # N: Revealed type is 'myapp.models.Order_RelatedManager' + reveal_type(User().orders) # N: Revealed type is 'main.Order_RelatedManager' reveal_type(User().orders.get()) # N: Revealed type is 'myapp.models.Order*' reveal_type(User().orders.manager_method()) # N: Revealed type is 'builtins.int' installed_apps: @@ -663,9 +663,10 @@ - path: myapp/models.py content: | from django.db import models + class User(models.Model): pass - class OrderManager(models.Manager): + class OrderManager(models.Manager['Order']): def manager_method(self) -> int: pass class Order(models.Model): diff --git a/test-data/typecheck/managers/querysets/test_from_queryset.yml b/test-data/typecheck/managers/querysets/test_from_queryset.yml index e9f2ad4ff..e858ea4c4 100644 --- a/test-data/typecheck/managers/querysets/test_from_queryset.yml +++ b/test-data/typecheck/managers/querysets/test_from_queryset.yml @@ -3,6 +3,7 @@ from myapp.models import MyModel reveal_type(MyModel().objects) # N: Revealed type is 'myapp.models.MyModel_NewManager[myapp.models.MyModel]' reveal_type(MyModel().objects.get()) # N: Revealed type is 'myapp.models.MyModel*' + reveal_type(MyModel().objects.queryset_method) # N: Revealed type is 'def () -> builtins.str' reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is 'builtins.str' installed_apps: - myapp @@ -178,4 +179,58 @@ from django.db import models class BaseQuerySet(models.QuerySet): def base_queryset_method(self, param: Union[int, str]) -> NoReturn: - raise ValueError \ No newline at end of file + raise ValueError + + +- case: from_queryset_with_inherited_manager_and_fk_to_auth_contrib + disable_cache: true + main: | + from myapp.base_queryset import BaseQuerySet + reveal_type(BaseQuerySet().base_queryset_method) # N: Revealed type is 'def (param: builtins.dict[builtins.str, Union[builtins.int, builtins.str]]) -> Union[builtins.int, builtins.str]' + + from django.contrib.auth.models import Permission + reveal_type(Permission().another_models) # N: Revealed type is 'django.db.models.manager.RelatedManager[myapp.models.AnotherModelInProjectWithContribAuthM2M]' + + from myapp.managers import NewManager + reveal_type(NewManager()) # N: Revealed type is 'myapp.managers.NewManager' + reveal_type(NewManager().base_queryset_method) # N: Revealed type is 'def (param: builtins.dict[builtins.str, Union[builtins.int, builtins.str]]) -> Union[builtins.int, builtins.str]' + + from myapp.models import MyModel + reveal_type(MyModel().objects) # N: Revealed type is 'myapp.models.MyModel_NewManager[myapp.models.MyModel]' + reveal_type(MyModel().objects.get()) # N: Revealed type is 'myapp.models.MyModel*' + reveal_type(MyModel().objects.base_queryset_method) # N: Revealed type is 'def (param: builtins.dict[builtins.str, Union[builtins.int, builtins.str]]) -> Union[builtins.int, builtins.str]' + installed_apps: + - myapp + - django.contrib.auth + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + from myapp.managers import NewManager + from django.contrib.auth.models import Permission + + class MyModel(models.Model): + objects = NewManager() + + class AnotherModelInProjectWithContribAuthM2M(models.Model): + permissions = models.ForeignKey( + Permission, + on_delete=models.PROTECT, + related_name='another_models' + ) + - path: myapp/managers.py + content: | + from django.db import models + from myapp.base_queryset import BaseQuerySet + from typing import Union, Dict + class ModelQuerySet(BaseQuerySet): + pass + NewManager = models.Manager.from_queryset(ModelQuerySet) + - path: myapp/base_queryset.py + content: | + from typing import Union, Dict + from django.db import models + class BaseQuerySet(models.QuerySet): + def base_queryset_method(self, param: Dict[str, Union[int, str]]) -> Union[int, str]: + return param["hello"] diff --git a/test-data/typecheck/managers/querysets/test_values_list.yml b/test-data/typecheck/managers/querysets/test_values_list.yml index c67f2658d..b1437520e 100644 --- a/test-data/typecheck/managers/querysets/test_values_list.yml +++ b/test-data/typecheck/managers/querysets/test_values_list.yml @@ -220,6 +220,7 @@ - path: myapp/models.py content: | from django.db import models + class TransactionQuerySet(models.QuerySet['Transaction']): pass class Transaction(models.Model): diff --git a/test-data/typecheck/managers/test_managers.yml b/test-data/typecheck/managers/test_managers.yml index 6018938fa..00b633c07 100644 --- a/test-data/typecheck/managers/test_managers.yml +++ b/test-data/typecheck/managers/test_managers.yml @@ -30,6 +30,7 @@ class Child(Parent): pass + - case: test_model_objects_attribute_present_in_case_of_model_cls_passed_as_generic_parameter main: | from myapp.models import Base, MyModel @@ -55,23 +56,25 @@ def method(self) -> None: reveal_type(self.model_cls._default_manager) # N: Revealed type is 'django.db.models.manager.Manager[myapp.models.MyModel]' + - case: if_custom_manager_defined_it_is_set_to_default_manager main: | from myapp.models import MyModel reveal_type(MyModel._default_manager) # N: Revealed type is 'myapp.models.CustomManager[myapp.models.MyModel]' + reveal_type(MyModel._default_manager.get()) # N: Revealed type is 'myapp.models.MyModel*' installed_apps: - myapp files: - path: myapp/__init__.py - path: myapp/models.py content: | - from typing import TypeVar from django.db import models - _T = TypeVar('_T', bound=models.Model) - class CustomManager(models.Manager[_T]): + + class CustomManager(models.Manager['MyModel']): pass class MyModel(models.Model): - manager = CustomManager['MyModel']() + manager = CustomManager() + - case: if_default_manager_name_is_passed_set_default_manager_to_it main: | @@ -83,40 +86,48 @@ - path: myapp/__init__.py - path: myapp/models.py content: | - from typing import TypeVar from django.db import models - _T = TypeVar('_T', bound=models.Model) - class Manager1(models.Manager[_T]): + class Manager1(models.Manager['MyModel']): pass - class Manager2(models.Manager[_T]): + class Manager2(models.Manager['MyModel']): pass class MyModel(models.Model): class Meta: default_manager_name = 'm2' - m1 = Manager1['MyModel']() - m2 = Manager2['MyModel']() + m1 = Manager1() + m2 = Manager2() -- case: test_leave_as_is_if_objects_is_set_and_fill_typevars_with_outer_class + +- case: manager_requires_type_annotation_to_be_set_if_generic_is_not_specified main: | - from myapp.models import MyUser - reveal_type(MyUser.objects) # N: Revealed type is 'myapp.models.UserManager[myapp.models.MyUser]' - reveal_type(MyUser.objects.get()) # N: Revealed type is 'myapp.models.MyUser*' - reveal_type(MyUser.objects.get_or_404()) # N: Revealed type is 'myapp.models.MyUser' + from myapp.models import MyModel + reveal_type(MyModel.objects) # N: Revealed type is 'myapp.models.MyModel_MyManager[myapp.models.MyModel]' + reveal_type(MyModel.objects.get()) # N: Revealed type is 'myapp.models.MyModel*' + reveal_type(MyModel.objects2) # N: Revealed type is 'myapp.models.MyGenericManager[Any]' + reveal_type(MyModel.objects2.get()) # N: Revealed type is 'Any' + reveal_type(MyModel.objects3) # N: Revealed type is 'myapp.models.MyGenericManager[myapp.models.MyModel]' + reveal_type(MyModel.objects3.get()) # N: Revealed type is 'myapp.models.MyModel*' installed_apps: - myapp files: - path: myapp/__init__.py - path: myapp/models.py content: | + from typing import TypeVar from django.db import models - class UserManager(models.Manager['MyUser']): - def get_or_404(self) -> 'MyUser': - pass + class MyManager(models.Manager): + pass + _T = TypeVar('_T', bound=models.Model) + class MyGenericManager(models.Manager[_T]): + pass + + class MyModel(models.Model): + objects = MyManager() + objects2 = MyGenericManager() # E: Need type annotation for 'objects2' + objects3: 'MyGenericManager[MyModel]' = MyGenericManager() - class MyUser(models.Model): - objects = UserManager() - case: model_imported_from_different_file main: | @@ -139,13 +150,14 @@ class Inventory(models.Model): pass + - case: managers_that_defined_on_other_models_do_not_influence main: | from myapp.models import AbstractPerson, Book reveal_type(AbstractPerson.abstract_persons) # N: Revealed type is 'django.db.models.manager.Manager[myapp.models.AbstractPerson]' - reveal_type(Book.published_objects) # N: Revealed type is 'myapp.models.PublishedBookManager[myapp.models.Book]' + reveal_type(Book.published_objects) # N: Revealed type is 'myapp.models.PublishedBookManager' Book.published_objects.create(title='hello') - reveal_type(Book.annotated_objects) # N: Revealed type is 'myapp.models.AnnotatedBookManager[myapp.models.Book]' + reveal_type(Book.annotated_objects) # N: Revealed type is 'myapp.models.AnnotatedBookManager' Book.annotated_objects.create(title='hello') installed_apps: - myapp @@ -166,7 +178,8 @@ published_objects = PublishedBookManager() annotated_objects = AnnotatedBookManager() -- case: managers_inherited_from_abstract_classes_multiple_inheritance + +- case: managers_inherited_from_abstract_classes_multiple_inheritance_do_not_warn_on_liskov main: | installed_apps: - myapp @@ -175,6 +188,7 @@ - path: myapp/models.py content: | from django.db import models + class CustomManager1(models.Manager['AbstractBase1']): pass class AbstractBase1(models.Model): @@ -193,6 +207,7 @@ class Child(AbstractBase1, AbstractBase2): pass + - case: model_has_a_manager_of_the_same_type main: | from myapp.models import UnrelatedModel, MyModel @@ -208,59 +223,19 @@ - path: myapp/models.py content: | from django.db import models + class UnrelatedModel(models.Model): objects = models.Manager['UnrelatedModel']() class MyModel(models.Model): pass -- case: manager_without_annotation_of_the_model_gets_it_from_outer_one - main: | - from myapp.models import UnrelatedModel2, MyModel2 - reveal_type(UnrelatedModel2.objects) # N: Revealed type is 'django.db.models.manager.Manager[myapp.models.UnrelatedModel2]' - reveal_type(UnrelatedModel2.objects.first()) # N: Revealed type is 'Union[myapp.models.UnrelatedModel2*, None]' - - reveal_type(MyModel2.objects) # N: Revealed type is 'django.db.models.manager.Manager[myapp.models.MyModel2]' - reveal_type(MyModel2.objects.first()) # N: Revealed type is 'Union[myapp.models.MyModel2*, None]' - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - class UnrelatedModel2(models.Model): - objects = models.Manager() - - class MyModel2(models.Model): - pass - -- case: inherited_manager_has_the_proper_type_of_model - main: | - from myapp.models import ParentOfMyModel3, MyModel3 - reveal_type(ParentOfMyModel3.objects) # N: Revealed type is 'django.db.models.manager.Manager[myapp.models.ParentOfMyModel3]' - reveal_type(ParentOfMyModel3.objects.first()) # N: Revealed type is 'Union[myapp.models.ParentOfMyModel3*, None]' - - reveal_type(MyModel3.objects) # N: Revealed type is 'django.db.models.manager.Manager[myapp.models.MyModel3]' - reveal_type(MyModel3.objects.first()) # N: Revealed type is 'Union[myapp.models.MyModel3*, None]' - installed_apps: - - myapp - files: - - path: myapp/__init__.py - - path: myapp/models.py - content: | - from django.db import models - class ParentOfMyModel3(models.Model): - objects = models.Manager() - - class MyModel3(ParentOfMyModel3): - pass - case: inheritance_with_explicit_type_on_child_manager main: | from myapp.models import ParentOfMyModel4, MyModel4 - reveal_type(ParentOfMyModel4.objects) # N: Revealed type is 'django.db.models.manager.Manager[myapp.models.ParentOfMyModel4]' - reveal_type(ParentOfMyModel4.objects.first()) # N: Revealed type is 'Union[myapp.models.ParentOfMyModel4*, None]' + reveal_type(ParentOfMyModel4.objects) # N: Revealed type is 'django.db.models.manager.Manager[Any]' + reveal_type(ParentOfMyModel4.objects.first()) # N: Revealed type is 'Union[Any, None]' reveal_type(MyModel4.objects) # N: Revealed type is 'django.db.models.manager.Manager[myapp.models.MyModel4]' reveal_type(MyModel4.objects.first()) # N: Revealed type is 'Union[myapp.models.MyModel4*, None]' @@ -271,54 +246,42 @@ - path: myapp/models.py content: | from django.db import models + class ParentOfMyModel4(models.Model): - objects = models.Manager() + objects = models.Manager() # E: Need type annotation for 'objects' class MyModel4(ParentOfMyModel4): objects = models.Manager['MyModel4']() -# TODO: make it work someday -#- case: inheritance_of_two_models_with_custom_objects_manager -# main: | -# from myapp.models import MyBaseUser, MyUser -# reveal_type(MyBaseUser.objects) # N: Revealed type is 'myapp.models.MyBaseManager[myapp.models.MyBaseUser]' -# reveal_type(MyBaseUser.objects.get()) # N: Revealed type is 'myapp.models.MyBaseUser' -# -# reveal_type(MyUser.objects) # N: Revealed type is 'myapp.models.MyManager[myapp.models.MyUser]' -# reveal_type(MyUser.objects.get()) # N: Revealed type is 'myapp.models.MyUser' -# installed_apps: -# - myapp -# files: -# - path: myapp/__init__.py -# - path: myapp/models.py -# content: | -# from django.db import models -# -# class MyBaseManager(models.Manager): -# pass -# class MyBaseUser(models.Model): -# objects = MyBaseManager() -# -# class MyManager(models.Manager): -# pass -# class MyUser(MyBaseUser): -# objects = MyManager() - -- case: custom_manager_returns_proper_model_types + +- case: custom_manager_annotate_method_before_type_declaration main: | - from myapp.models import User - reveal_type(User.objects) # N: Revealed type is 'myapp.models.User_MyManager2[myapp.models.User]' - reveal_type(User.objects.select_related()) # N: Revealed type is 'myapp.models.User_MyManager2[myapp.models.User]' - reveal_type(User.objects.get()) # N: Revealed type is 'myapp.models.User*' - reveal_type(User.objects.get_instance()) # N: Revealed type is 'builtins.int' - reveal_type(User.objects.get_instance_untyped('hello')) # N: Revealed type is 'Any' - - from myapp.models import ChildUser - reveal_type(ChildUser.objects) # N: Revealed type is 'myapp.models.ChildUser_MyManager2[myapp.models.ChildUser]' - reveal_type(ChildUser.objects.select_related()) # N: Revealed type is 'myapp.models.ChildUser_MyManager2[myapp.models.ChildUser]' - reveal_type(ChildUser.objects.get()) # N: Revealed type is 'myapp.models.ChildUser*' - reveal_type(ChildUser.objects.get_instance()) # N: Revealed type is 'builtins.int' - reveal_type(ChildUser.objects.get_instance_untyped('hello')) # N: Revealed type is 'Any' + from myapp.models import ModelA, ModelB, ManagerA + reveal_type(ModelA.objects) # N: Revealed type is 'myapp.models.ModelA_ManagerA1[myapp.models.ModelA]' + reveal_type(ModelA.objects.do_something) # N: Revealed type is 'def (other_obj: myapp.models.ModelB) -> builtins.str' + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class ManagerA(models.Manager): + def do_something(self, other_obj: "ModelB") -> str: + return 'test' + class ModelA(models.Model): + title = models.TextField() + objects = ManagerA() + class ModelB(models.Model): + movie = models.TextField() + + +- case: manager_defined_in_the_nested_class + main: | + from myapp.models import MyModel + reveal_type(MyModel.objects) # N: Revealed type is 'myapp.models.MyModel.MyManager' + reveal_type(MyModel.objects.get()) # N: Revealed type is 'myapp.models.MyModel*' + reveal_type(MyModel.objects.mymethod()) # N: Revealed type is 'builtins.int' installed_apps: - myapp files: @@ -326,21 +289,18 @@ - path: myapp/models.py content: | from django.db import models - class MyManager(models.Manager): - def get_instance(self) -> int: - pass - def get_instance_untyped(self, name): - pass - class User(models.Model): - objects = MyManager() - class ChildUser(models.Model): + + class MyModel(models.Model): + class MyManager(models.Manager['MyModel']): + def mymethod(self) -> int: + pass objects = MyManager() -- case: custom_manager_annotate_method_before_type_declaration + +- case: manager_method_is_forward_reference main: | - from myapp.models import ModelA, ModelB, ManagerA - reveal_type(ModelA.objects) # N: Revealed type is 'myapp.models.ModelA_ManagerA1[myapp.models.ModelA]' - reveal_type(ModelA.objects.do_something) # N: Revealed type is 'def (other_obj: myapp.models.ModelB) -> builtins.str' + from myapp.models import ModelA + reveal_type(ModelA.objects.do_something()) # N: Revealed type is 'myapp.models.ModelB' installed_apps: - myapp files: @@ -349,11 +309,10 @@ content: | from django.db import models class ManagerA(models.Manager): - def do_something(self, other_obj: "ModelB") -> str: - return 'test' + def do_something(self) -> "ModelB": + return ModelB.objects.create(movie="There's something about mypy") class ModelA(models.Model): title = models.TextField() objects = ManagerA() class ModelB(models.Model): movie = models.TextField() - diff --git a/test-data/typecheck/test_request.yml b/test-data/typecheck/test_request.yml index 48a4ec7b1..c521f1e18 100644 --- a/test-data/typecheck/test_request.yml +++ b/test-data/typecheck/test_request.yml @@ -15,7 +15,8 @@ from django.db import models class MyUser(models.Model): pass -- case: request_object_user_can_be_descriminated + +- case: request_object_user_can_be_discriminated disable_cache: true main: | from django.http.request import HttpRequest @@ -48,7 +49,7 @@ user: User # Override the type of user request = MyRequest() - reveal_type(request.user) # N: Revealed type is 'django.contrib.auth.models.User' + reveal_type(request.user) # N: Revealed type is 'Union[django.contrib.auth.models.User, django.contrib.auth.models.AnonymousUser]' custom_settings: | INSTALLED_APPS = ('django.contrib.contenttypes', 'django.contrib.auth')