Skip to content

Commit b759293

Browse files
TrueBraingvanrossum
authored andcommitted
For class variables, lookup type in base classes (#1338, #2022, #2211) (#2510)
Continuation of #2380 (which was reverted). Fixes #2503.
1 parent 92f28ba commit b759293

File tree

6 files changed

+472
-23
lines changed

6 files changed

+472
-23
lines changed

mypy/checker.py

Lines changed: 139 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
YieldFromExpr, NamedTupleExpr, TypedDictExpr, SetComprehension,
2525
DictionaryComprehension, ComplexExpr, EllipsisExpr, TypeAliasExpr,
2626
RefExpr, YieldExpr, BackquoteExpr, ImportFrom, ImportAll, ImportBase,
27-
AwaitExpr, PromoteExpr,
28-
ARG_POS,
27+
AwaitExpr, PromoteExpr, Node,
28+
ARG_POS, MDEF,
2929
CONTRAVARIANT, COVARIANT)
3030
from mypy import nodes
3131
from mypy.types import (
@@ -44,7 +44,8 @@
4444
restrict_subtype_away, is_subtype_ignoring_tvars
4545
)
4646
from mypy.maptype import map_instance_to_supertype
47-
from mypy.semanal import fill_typevars, set_callable_name, refers_to_fullname
47+
from mypy.typevars import fill_typevars, has_no_typevars
48+
from mypy.semanal import set_callable_name, refers_to_fullname
4849
from mypy.erasetype import erase_typevars
4950
from mypy.expandtype import expand_type, expand_type_by_instance
5051
from mypy.visitor import NodeVisitor
@@ -1102,6 +1103,12 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
11021103
infer_lvalue_type)
11031104
else:
11041105
lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue)
1106+
1107+
if isinstance(lvalue, NameExpr):
1108+
if self.check_compatibility_all_supers(lvalue, lvalue_type, rvalue):
1109+
# We hit an error on this line; don't check for any others
1110+
return
1111+
11051112
if lvalue_type:
11061113
if isinstance(lvalue_type, PartialType) and lvalue_type.type is None:
11071114
# Try to infer a proper type for a variable with a partial None type.
@@ -1156,6 +1163,123 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
11561163
self.infer_variable_type(inferred, lvalue, self.accept(rvalue),
11571164
rvalue)
11581165

1166+
def check_compatibility_all_supers(self, lvalue: NameExpr, lvalue_type: Type,
1167+
rvalue: Expression) -> bool:
1168+
lvalue_node = lvalue.node
1169+
1170+
# Check if we are a class variable with at least one base class
1171+
if (isinstance(lvalue_node, Var) and
1172+
lvalue.kind == MDEF and
1173+
len(lvalue_node.info.bases) > 0):
1174+
1175+
for base in lvalue_node.info.mro[1:]:
1176+
# Only check __slots__ against the 'object'
1177+
# If a base class defines a Tuple of 3 elements, a child of
1178+
# this class should not be allowed to define it as a Tuple of
1179+
# anything other than 3 elements. The exception to this rule
1180+
# is __slots__, where it is allowed for any child class to
1181+
# redefine it.
1182+
if lvalue_node.name() == "__slots__" and base.fullname() != "builtins.object":
1183+
continue
1184+
1185+
base_type, base_node = self.lvalue_type_from_base(lvalue_node, base)
1186+
1187+
if base_type:
1188+
if not self.check_compatibility_super(lvalue,
1189+
lvalue_type,
1190+
rvalue,
1191+
base,
1192+
base_type,
1193+
base_node):
1194+
# Only show one error per variable; even if other
1195+
# base classes are also incompatible
1196+
return True
1197+
break
1198+
return False
1199+
1200+
def check_compatibility_super(self, lvalue: NameExpr, lvalue_type: Type, rvalue: Expression,
1201+
base: TypeInfo, base_type: Type, base_node: Node) -> bool:
1202+
lvalue_node = lvalue.node
1203+
assert isinstance(lvalue_node, Var)
1204+
1205+
# Do not check whether the rvalue is compatible if the
1206+
# lvalue had a type defined; this is handled by other
1207+
# parts, and all we have to worry about in that case is
1208+
# that lvalue is compatible with the base class.
1209+
compare_node = None # type: Node
1210+
if lvalue_type:
1211+
compare_type = lvalue_type
1212+
compare_node = lvalue.node
1213+
else:
1214+
compare_type = self.accept(rvalue, base_type)
1215+
if isinstance(rvalue, NameExpr):
1216+
compare_node = rvalue.node
1217+
if isinstance(compare_node, Decorator):
1218+
compare_node = compare_node.func
1219+
1220+
if compare_type:
1221+
if (isinstance(base_type, CallableType) and
1222+
isinstance(compare_type, CallableType)):
1223+
base_static = is_node_static(base_node)
1224+
compare_static = is_node_static(compare_node)
1225+
1226+
# In case compare_static is unknown, also check
1227+
# if 'definition' is set. The most common case for
1228+
# this is with TempNode(), where we lose all
1229+
# information about the real rvalue node (but only get
1230+
# the rvalue type)
1231+
if compare_static is None and compare_type.definition:
1232+
compare_static = is_node_static(compare_type.definition)
1233+
1234+
# Compare against False, as is_node_static can return None
1235+
if base_static is False and compare_static is False:
1236+
# Class-level function objects and classmethods become bound
1237+
# methods: the former to the instance, the latter to the
1238+
# class
1239+
base_type = bind_self(base_type, self.scope.active_class())
1240+
compare_type = bind_self(compare_type, self.scope.active_class())
1241+
1242+
# If we are a static method, ensure to also tell the
1243+
# lvalue it now contains a static method
1244+
if base_static and compare_static:
1245+
lvalue_node.is_staticmethod = True
1246+
1247+
return self.check_subtype(compare_type, base_type, lvalue,
1248+
messages.INCOMPATIBLE_TYPES_IN_ASSIGNMENT,
1249+
'expression has type',
1250+
'base class "%s" defined the type as' % base.name())
1251+
return True
1252+
1253+
def lvalue_type_from_base(self, expr_node: Var,
1254+
base: TypeInfo) -> Tuple[Optional[Type], Optional[Node]]:
1255+
"""For a NameExpr that is part of a class, walk all base classes and try
1256+
to find the first class that defines a Type for the same name."""
1257+
expr_name = expr_node.name()
1258+
base_var = base.names.get(expr_name)
1259+
1260+
if base_var:
1261+
base_node = base_var.node
1262+
base_type = base_var.type
1263+
if isinstance(base_node, Decorator):
1264+
base_node = base_node.func
1265+
base_type = base_node.type
1266+
1267+
if base_type:
1268+
if not has_no_typevars(base_type):
1269+
instance = cast(Instance, self.scope.active_class())
1270+
itype = map_instance_to_supertype(instance, base)
1271+
base_type = expand_type_by_instance(base_type, itype)
1272+
1273+
if isinstance(base_type, CallableType) and isinstance(base_node, FuncDef):
1274+
# If we are a property, return the Type of the return
1275+
# value, not the Callable
1276+
if base_node.is_property:
1277+
base_type = base_type.ret_type
1278+
1279+
return base_type, base_node
1280+
1281+
return None, None
1282+
11591283
def check_assignment_to_multiple_lvalues(self, lvalues: List[Lvalue], rvalue: Expression,
11601284
context: Context,
11611285
infer_lvalue_type: bool = True) -> None:
@@ -2875,6 +2999,18 @@ def is_valid_inferred_type_component(typ: Type) -> bool:
28752999
return True
28763000

28773001

3002+
def is_node_static(node: Node) -> Optional[bool]:
3003+
"""Find out if a node describes a static function method."""
3004+
3005+
if isinstance(node, FuncDef):
3006+
return node.is_static
3007+
3008+
if isinstance(node, Var):
3009+
return node.is_staticmethod
3010+
3011+
return None
3012+
3013+
28783014
class Scope:
28793015
# We keep two stacks combined, to maintain the relative order
28803016
stack = None # type: List[Union[Type, FuncItem, MypyFile]]

mypy/checkexpr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from mypy.checkstrformat import StringFormatterChecker
4141
from mypy.expandtype import expand_type, expand_type_by_instance, freshen_function_type_vars
4242
from mypy.util import split_module_names
43-
from mypy.semanal import fill_typevars
43+
from mypy.typevars import fill_typevars
4444
from mypy.visitor import ExpressionVisitor
4545

4646
from mypy import experiments

mypy/checkmember.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from mypy.maptype import map_instance_to_supertype
1717
from mypy.expandtype import expand_type_by_instance, expand_type, freshen_function_type_vars
1818
from mypy.infer import infer_type_arguments
19-
from mypy.semanal import fill_typevars
19+
from mypy.typevars import fill_typevars
2020
from mypy import messages
2121
from mypy import subtypes
2222
MYPY = False

mypy/semanal.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
IntExpr, FloatExpr, UnicodeExpr, EllipsisExpr, TempNode,
6868
COVARIANT, CONTRAVARIANT, INVARIANT, UNBOUND_IMPORTED, LITERAL_YES,
6969
)
70+
from mypy.typevars import has_no_typevars, fill_typevars
7071
from mypy.visitor import NodeVisitor
7172
from mypy.traverser import TraverserVisitor
7273
from mypy.errors import Errors, report_internal_error
@@ -79,7 +80,6 @@
7980
from mypy.typeanal import TypeAnalyser, TypeAnalyserPass3, analyze_type_alias
8081
from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError
8182
from mypy.sametypes import is_same_type
82-
from mypy.erasetype import erase_typevars
8383
from mypy.options import Options
8484
from mypy import join
8585

@@ -3204,19 +3204,6 @@ def builtin_type(self, name: str, args: List[Type] = None) -> Instance:
32043204
return Instance(sym.node, args or [])
32053205

32063206

3207-
def fill_typevars(typ: TypeInfo) -> Union[Instance, TupleType]:
3208-
"""For a non-generic type, return instance type representing the type.
3209-
For a generic G type with parameters T1, .., Tn, return G[T1, ..., Tn].
3210-
"""
3211-
tv = [] # type: List[Type]
3212-
for i in range(len(typ.type_vars)):
3213-
tv.append(TypeVarType(typ.defn.type_vars[i]))
3214-
inst = Instance(typ, tv)
3215-
if typ.tuple_type is None:
3216-
return inst
3217-
return typ.tuple_type.copy_modified(fallback=inst)
3218-
3219-
32203207
def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike:
32213208
if isinstance(sig, CallableType):
32223209
return sig.copy_modified(arg_types=[new] + sig.arg_types[1:])
@@ -3588,7 +3575,3 @@ def find_fixed_callable_return(expr: Expression) -> Optional[CallableType]:
35883575
if isinstance(t.ret_type, CallableType):
35893576
return t.ret_type
35903577
return None
3591-
3592-
3593-
def has_no_typevars(typ: Type) -> bool:
3594-
return is_same_type(typ, erase_typevars(typ))

mypy/typevars.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from typing import Union
2+
3+
from mypy.nodes import TypeInfo
4+
5+
from mypy.erasetype import erase_typevars
6+
from mypy.sametypes import is_same_type
7+
from mypy.types import Instance, TypeVarType, TupleType, Type
8+
9+
10+
def fill_typevars(typ: TypeInfo) -> Union[Instance, TupleType]:
11+
"""For a non-generic type, return instance type representing the type.
12+
For a generic G type with parameters T1, .., Tn, return G[T1, ..., Tn].
13+
"""
14+
tv = [] # type: List[Type]
15+
for i in range(len(typ.type_vars)):
16+
tv.append(TypeVarType(typ.defn.type_vars[i]))
17+
inst = Instance(typ, tv)
18+
if typ.tuple_type is None:
19+
return inst
20+
return typ.tuple_type.copy_modified(fallback=inst)
21+
22+
23+
def has_no_typevars(typ: Type) -> bool:
24+
return is_same_type(typ, erase_typevars(typ))

0 commit comments

Comments
 (0)