diff --git a/mypy/checker.py b/mypy/checker.py index 698d4de078f3..1172e1972db9 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -24,8 +24,8 @@ YieldFromExpr, NamedTupleExpr, TypedDictExpr, SetComprehension, DictionaryComprehension, ComplexExpr, EllipsisExpr, TypeAliasExpr, RefExpr, YieldExpr, BackquoteExpr, ImportFrom, ImportAll, ImportBase, - AwaitExpr, PromoteExpr, - ARG_POS, + AwaitExpr, PromoteExpr, Node, + ARG_POS, MDEF, CONTRAVARIANT, COVARIANT) from mypy import nodes from mypy.types import ( @@ -44,7 +44,8 @@ restrict_subtype_away, is_subtype_ignoring_tvars ) from mypy.maptype import map_instance_to_supertype -from mypy.semanal import fill_typevars, set_callable_name, refers_to_fullname +from mypy.typevars import fill_typevars, has_no_typevars +from mypy.semanal import set_callable_name, refers_to_fullname from mypy.erasetype import erase_typevars from mypy.expandtype import expand_type, expand_type_by_instance from mypy.visitor import NodeVisitor @@ -1102,6 +1103,12 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type infer_lvalue_type) else: lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue) + + if isinstance(lvalue, NameExpr): + if self.check_compatibility_all_supers(lvalue, lvalue_type, rvalue): + # We hit an error on this line; don't check for any others + return + if lvalue_type: if isinstance(lvalue_type, PartialType) and lvalue_type.type is None: # Try to infer a proper type for a variable with a partial None type. @@ -1156,6 +1163,123 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type self.infer_variable_type(inferred, lvalue, self.accept(rvalue), rvalue) + def check_compatibility_all_supers(self, lvalue: NameExpr, lvalue_type: Type, + rvalue: Expression) -> bool: + lvalue_node = lvalue.node + + # Check if we are a class variable with at least one base class + if (isinstance(lvalue_node, Var) and + lvalue.kind == MDEF and + len(lvalue_node.info.bases) > 0): + + for base in lvalue_node.info.mro[1:]: + # Only check __slots__ against the 'object' + # If a base class defines a Tuple of 3 elements, a child of + # this class should not be allowed to define it as a Tuple of + # anything other than 3 elements. The exception to this rule + # is __slots__, where it is allowed for any child class to + # redefine it. + if lvalue_node.name() == "__slots__" and base.fullname() != "builtins.object": + continue + + base_type, base_node = self.lvalue_type_from_base(lvalue_node, base) + + if base_type: + if not self.check_compatibility_super(lvalue, + lvalue_type, + rvalue, + base, + base_type, + base_node): + # Only show one error per variable; even if other + # base classes are also incompatible + return True + break + return False + + def check_compatibility_super(self, lvalue: NameExpr, lvalue_type: Type, rvalue: Expression, + base: TypeInfo, base_type: Type, base_node: Node) -> bool: + lvalue_node = lvalue.node + assert isinstance(lvalue_node, Var) + + # Do not check whether the rvalue is compatible if the + # lvalue had a type defined; this is handled by other + # parts, and all we have to worry about in that case is + # that lvalue is compatible with the base class. + compare_node = None # type: Node + if lvalue_type: + compare_type = lvalue_type + compare_node = lvalue.node + else: + compare_type = self.accept(rvalue, base_type) + if isinstance(rvalue, NameExpr): + compare_node = rvalue.node + if isinstance(compare_node, Decorator): + compare_node = compare_node.func + + if compare_type: + if (isinstance(base_type, CallableType) and + isinstance(compare_type, CallableType)): + base_static = is_node_static(base_node) + compare_static = is_node_static(compare_node) + + # In case compare_static is unknown, also check + # if 'definition' is set. The most common case for + # this is with TempNode(), where we lose all + # information about the real rvalue node (but only get + # the rvalue type) + if compare_static is None and compare_type.definition: + compare_static = is_node_static(compare_type.definition) + + # Compare against False, as is_node_static can return None + if base_static is False and compare_static is False: + # Class-level function objects and classmethods become bound + # methods: the former to the instance, the latter to the + # class + base_type = bind_self(base_type, self.scope.active_class()) + compare_type = bind_self(compare_type, self.scope.active_class()) + + # If we are a static method, ensure to also tell the + # lvalue it now contains a static method + if base_static and compare_static: + lvalue_node.is_staticmethod = True + + return self.check_subtype(compare_type, base_type, lvalue, + messages.INCOMPATIBLE_TYPES_IN_ASSIGNMENT, + 'expression has type', + 'base class "%s" defined the type as' % base.name()) + return True + + def lvalue_type_from_base(self, expr_node: Var, + base: TypeInfo) -> Tuple[Optional[Type], Optional[Node]]: + """For a NameExpr that is part of a class, walk all base classes and try + to find the first class that defines a Type for the same name.""" + expr_name = expr_node.name() + base_var = base.names.get(expr_name) + + if base_var: + base_node = base_var.node + base_type = base_var.type + if isinstance(base_node, Decorator): + base_node = base_node.func + base_type = base_node.type + + if base_type: + if not has_no_typevars(base_type): + instance = cast(Instance, self.scope.active_class()) + itype = map_instance_to_supertype(instance, base) + base_type = expand_type_by_instance(base_type, itype) + + if isinstance(base_type, CallableType) and isinstance(base_node, FuncDef): + # If we are a property, return the Type of the return + # value, not the Callable + if base_node.is_property: + base_type = base_type.ret_type + + return base_type, base_node + + return None, None + def check_assignment_to_multiple_lvalues(self, lvalues: List[Lvalue], rvalue: Expression, context: Context, infer_lvalue_type: bool = True) -> None: @@ -2835,6 +2959,18 @@ def is_valid_inferred_type_component(typ: Type) -> bool: return True +def is_node_static(node: Node) -> Optional[bool]: + """Find out if a node describes a static function method.""" + + if isinstance(node, FuncDef): + return node.is_static + + if isinstance(node, Var): + return node.is_staticmethod + + return None + + class Scope: # We keep two stacks combined, to maintain the relative order stack = None # type: List[Union[Type, FuncItem, MypyFile]] diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index f86050077e49..77a2b09770a6 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -40,7 +40,7 @@ from mypy.checkstrformat import StringFormatterChecker from mypy.expandtype import expand_type, expand_type_by_instance from mypy.util import split_module_names -from mypy.semanal import fill_typevars +from mypy.typevars import fill_typevars from mypy.visitor import ExpressionVisitor from mypy import experiments diff --git a/mypy/checkmember.py b/mypy/checkmember.py index dd477cd87fbe..5dc9d1c6b7ff 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -16,7 +16,7 @@ from mypy.maptype import map_instance_to_supertype from mypy.expandtype import expand_type_by_instance, expand_type from mypy.infer import infer_type_arguments -from mypy.semanal import fill_typevars +from mypy.typevars import fill_typevars from mypy import messages from mypy import subtypes MYPY = False diff --git a/mypy/semanal.py b/mypy/semanal.py index 82e0289b5d94..145117bfe667 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -67,6 +67,7 @@ IntExpr, FloatExpr, UnicodeExpr, EllipsisExpr, TempNode, COVARIANT, CONTRAVARIANT, INVARIANT, UNBOUND_IMPORTED, LITERAL_YES, ) +from mypy.typevars import has_no_typevars, fill_typevars from mypy.visitor import NodeVisitor from mypy.traverser import TraverserVisitor from mypy.errors import Errors, report_internal_error @@ -79,7 +80,6 @@ from mypy.typeanal import TypeAnalyser, TypeAnalyserPass3, analyze_type_alias from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError from mypy.sametypes import is_same_type -from mypy.erasetype import erase_typevars from mypy.options import Options from mypy import join @@ -3165,19 +3165,6 @@ def builtin_type(self, name: str, args: List[Type] = None) -> Instance: return Instance(sym.node, args or []) -def fill_typevars(typ: TypeInfo) -> Union[Instance, TupleType]: - """For a non-generic type, return instance type representing the type. - For a generic G type with parameters T1, .., Tn, return G[T1, ..., Tn]. - """ - tv = [] # type: List[Type] - for i in range(len(typ.type_vars)): - tv.append(TypeVarType(typ.defn.type_vars[i])) - inst = Instance(typ, tv) - if typ.tuple_type is None: - return inst - return typ.tuple_type.copy_modified(fallback=inst) - - def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike: if isinstance(sig, CallableType): return sig.copy_modified(arg_types=[new] + sig.arg_types[1:]) @@ -3549,7 +3536,3 @@ def find_fixed_callable_return(expr: Expression) -> Optional[CallableType]: if isinstance(t.ret_type, CallableType): return t.ret_type return None - - -def has_no_typevars(typ: Type) -> bool: - return is_same_type(typ, erase_typevars(typ)) diff --git a/mypy/typevars.py b/mypy/typevars.py new file mode 100644 index 000000000000..1bdb1049ebed --- /dev/null +++ b/mypy/typevars.py @@ -0,0 +1,24 @@ +from typing import Union + +from mypy.nodes import TypeInfo + +from mypy.erasetype import erase_typevars +from mypy.sametypes import is_same_type +from mypy.types import Instance, TypeVarType, TupleType, Type + + +def fill_typevars(typ: TypeInfo) -> Union[Instance, TupleType]: + """For a non-generic type, return instance type representing the type. + For a generic G type with parameters T1, .., Tn, return G[T1, ..., Tn]. + """ + tv = [] # type: List[Type] + for i in range(len(typ.type_vars)): + tv.append(TypeVarType(typ.defn.type_vars[i])) + inst = Instance(typ, tv) + if typ.tuple_type is None: + return inst + return typ.tuple_type.copy_modified(fallback=inst) + + +def has_no_typevars(typ: Type) -> bool: + return is_same_type(typ, erase_typevars(typ)) diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index 6719abe660d2..c71ffdc4270f 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -2453,3 +2453,309 @@ class B(object, A): # E: Cannot determine consistent method resolution order (MR # flags: --fast-parser class C(metaclass=int()): # E: Dynamic metaclass not supported for 'C' pass + +[case testVariableSubclass] +class A: + a = 1 # type: int +class B(A): + a = 1 +[out] + +[case testVariableSubclassAssignMismatch] +class A: + a = 1 # type: int +class B(A): + a = "a" +[out] +main:4: error: Incompatible types in assignment (expression has type "str", base class "A" defined the type as "int") + +[case testVariableSubclassAssignment] +class A: + a = None # type: int +class B(A): + def __init__(self) -> None: + self.a = "a" +[out] +main:5: error: Incompatible types in assignment (expression has type "str", variable has type "int") + +[case testVariableSubclassTypeOverwrite] +class A: + a = None # type: int +class B(A): + a = None # type: str +class C(B): + a = "a" +[out] +main:4: error: Incompatible types in assignment (expression has type "str", base class "A" defined the type as "int") + +[case testVariableSubclassTypeOverwriteImplicit] +class A: + a = 1 +class B(A): + a = None # type: str +[out] +main:4: error: Incompatible types in assignment (expression has type "str", base class "A" defined the type as "int") + +[case testVariableSuperUsage] +class A: + a = [] # type: list +class B(A): + a = [1, 2] +class C(B): + a = B.a + [3] +[builtins fixtures/list.pyi] +[out] + +[case testClassAllBases] +from typing import Union +class A: + a = None # type: Union[int, str] +class B(A): + a = 1 +class C(B): + a = "str" +class D(A): + a = "str" +[out] +main:7: error: Incompatible types in assignment (expression has type "str", base class "B" defined the type as "int") + +[case testVariableTypeVar] +from typing import TypeVar, Generic +T = TypeVar('T') +class A(Generic[T]): + a = None # type: T +class B(A[int]): + a = 1 + +[case testVariableTypeVarInvalid] +from typing import TypeVar, Generic +T = TypeVar('T') +class A(Generic[T]): + a = None # type: T +class B(A[int]): + a = "abc" +[out] +main:6: error: Incompatible types in assignment (expression has type "str", base class "A" defined the type as "int") + +[case testVariableTypeVarIndirectly] +from typing import TypeVar, Generic +T = TypeVar('T') +class A(Generic[T]): + a = None # type: T +class B(A[int]): + pass +class C(B): + a = "a" +[out] +main:8: error: Incompatible types in assignment (expression has type "str", base class "A" defined the type as "int") + +[case testVariableTypeVarList] +from typing import List, TypeVar, Generic +T = TypeVar('T') +class A(Generic[T]): + a = None # type: List[T] + b = None # type: List[T] +class B(A[int]): + a = [1] + b = [''] +[builtins fixtures/list.pyi] +[out] +main:8: error: List item 0 has incompatible type "str" + +[case testVariableMethod] +class A: + def a(self) -> None: pass + b = 1 +class B(A): + a = 1 + def b(self) -> None: pass +[out] +main:5: error: Incompatible types in assignment (expression has type "int", base class "A" defined the type as Callable[[A], None]) +main:6: error: Signature of "b" incompatible with supertype "A" + +[case testVariableProperty] +class A: + @property + def a(self) -> bool: pass +class B(A): + a = None # type: bool +class C(A): + a = True +class D(A): + a = 1 +[builtins fixtures/property.pyi] +[out] +main:9: error: Incompatible types in assignment (expression has type "int", base class "A" defined the type as "bool") + +[case testVariableOverwriteAny] +from typing import Any +class A: + a = 1 +class B(A): + a = 'x' # type: Any +[out] + +[case testInstanceMethodOverwrite] +class B(): + def n(self, a: int) -> None: pass +class C(B): + def m(self, a: int) -> None: pass + n = m +[out] + +[case testInstanceMethodOverwriteError] +class B(): + def n(self, a: int) -> None: pass +class C(B): + def m(self, a: str) -> None: pass + n = m +[out] +main:5: error: Incompatible types in assignment (expression has type Callable[[str], None], base class "B" defined the type as Callable[[int], None]) + +[case testInstanceMethodOverwriteTypevar] +from typing import Generic, TypeVar +T = TypeVar("T") +class B(Generic[T]): + def n(self, a: T) -> None: pass +class C(B[int]): + def m(self, a: int) -> None: pass + n = m + +[case testInstanceMethodOverwriteTwice] +class I: + def foo(self) -> None: pass +class A(I): + def foo(self) -> None: pass +class B(A): + def bar(self) -> None: pass + foo = bar +class C(B): + def bar(self) -> None: pass + foo = bar + +[case testClassMethodOverwrite] +class B(): + @classmethod + def n(self, a: int) -> None: pass +class C(B): + @classmethod + def m(self, a: int) -> None: pass + n = m +[builtins fixtures/classmethod.pyi] +[out] + +[case testClassMethodOverwriteError] +class B(): + @classmethod + def n(self, a: int) -> None: pass +class C(B): + @classmethod + def m(self, a: str) -> None: pass + n = m +[builtins fixtures/classmethod.pyi] +[out] +main:7: error: Incompatible types in assignment (expression has type Callable[[str], None], base class "B" defined the type as Callable[[int], None]) + +[case testClassSpec] +from typing import Callable +class A(): + b = None # type: Callable[[A, int], int] +class B(A): + def c(self, a: int) -> int: pass + b = c + +[case testClassSpecError] +from typing import Callable +class A(): + b = None # type: Callable[[A, int], int] +class B(A): + def c(self, a: str) -> int: pass + b = c +[out] +main:6: error: Incompatible types in assignment (expression has type Callable[[str], int], base class "A" defined the type as Callable[[int], int]) + +[case testClassStaticMethod] +class A(): + @staticmethod + def a(a: int) -> None: pass +class B(A): + @staticmethod + def b(a: str) -> None: pass + a = b +[builtins fixtures/staticmethod.pyi] +[out] +main:7: error: Incompatible types in assignment (expression has type Callable[[str], None], base class "A" defined the type as Callable[[int], None]) + +[case testClassStaticMethodIndirect] +class A(): + @staticmethod + def a(a: int) -> None: pass + c = a +class B(A): + @staticmethod + def b(a: str) -> None: pass + c = b +[builtins fixtures/staticmethod.pyi] +[out] +main:8: error: Incompatible types in assignment (expression has type Callable[[str], None], base class "A" defined the type as Callable[[int], None]) + +[case testTempNode] +class A(): + def a(self) -> None: pass +class B(A): + def b(self) -> None: pass + a = c = b + +[case testListObject] +from typing import List +class A: + x = [] # type: List[object] +class B(A): + x = [1] +[builtins fixtures/list.pyi] + +[case testClassMemberObject] +class A: + x = object() +class B(A): + x = 1 +class C(B): + x = '' +[out] +main:6: error: Incompatible types in assignment (expression has type "str", base class "B" defined the type as "int") + +[case testSlots] +class A: + __slots__ = ("a") +class B(A): + __slots__ = ("a", "b") + +[case testClassOrderOfError] +class A: + x = 1 +class B(A): + x = "a" +class C(B): + x = object() +[out] +main:4: error: Incompatible types in assignment (expression has type "str", base class "A" defined the type as "int") +main:6: error: Incompatible types in assignment (expression has type "object", base class "B" defined the type as "str") + +[case testClassOneErrorPerLine] +class A: + x = 1 +class B(A): + x = "" + x = 1.0 +[out] +main:4: error: Incompatible types in assignment (expression has type "str", base class "A" defined the type as "int") +main:5: error: Incompatible types in assignment (expression has type "str", base class "A" defined the type as "int") + +[case testClassIgnoreType] +class A: + x = 0 +class B(A): + x = '' # type: ignore +class C(B): + x = '' +[out]