diff --git a/mypy/binder.py b/mypy/binder.py index 23be259e82cf..11761d962208 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -1,5 +1,6 @@ -from typing import (Dict, List, Set, Iterator, Union) +from typing import (Dict, List, Set, Iterator, Optional, DefaultDict, Tuple, Union) from contextlib import contextmanager +from collections import defaultdict from mypy.types import Type, AnyType, PartialType from mypy.nodes import (Key, Node, Expression, Var, RefExpr, SymbolTableNode) @@ -45,6 +46,7 @@ class A: reveal_type(lst[0].a) # str ``` """ + type_assignments = None # type: Optional[DefaultDict[Expression, List[Tuple[Type, Type]]]] def __init__(self) -> None: # The stack of frames currently used. These map @@ -198,10 +200,20 @@ def get_declaration(self, expr: Node) -> Type: else: return None + @contextmanager + def accumulate_type_assignments(self) -> Iterator[DefaultDict[Expression, + List[Tuple[Type, Type]]]]: + self.type_assignments = defaultdict(list) + yield self.type_assignments + self.type_assignments = None + def assign_type(self, expr: Expression, type: Type, declared_type: Type, restrict_any: bool = False) -> None: + if self.type_assignments is not None: + self.type_assignments[expr].append((type, declared_type)) + return if not expr.literal: return self.invalidate_dependencies(expr) diff --git a/mypy/checker.py b/mypy/checker.py index 6a0a987ad783..1e7c2bed26d8 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -48,7 +48,7 @@ from mypy.erasetype import erase_typevars from mypy.expandtype import expand_type from mypy.visitor import NodeVisitor -from mypy.join import join_types +from mypy.join import join_types, join_type_list from mypy.treetransform import TransformVisitor from mypy.meet import meet_simple, nearest_builtin_ancestor, is_overlapping_types from mypy.binder import ConditionalTypeBinder @@ -1070,8 +1070,15 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type new_syntax: bool = False) -> None: """Type check a single assignment: lvalue = rvalue.""" if isinstance(lvalue, TupleExpr) or isinstance(lvalue, ListExpr): - self.check_assignment_to_multiple_lvalues(lvalue.items, rvalue, lvalue, - infer_lvalue_type) + if isinstance(rvalue, TupleExpr) or isinstance(rvalue, ListExpr): + self.check_multi_assign_literal(lvalue.items, rvalue, lvalue, infer_lvalue_type) + return + # Infer the type of an ordinary rvalue expression. + # TODO maybe elsewhere; redundant + rvalue_type = self.accept(rvalue) + self.check_multi_assign(lvalue.items, rvalue, rvalue_type, lvalue, + undefined_rvalue=False, + infer_lvalue_type=infer_lvalue_type) else: lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue) if lvalue_type: @@ -1123,40 +1130,31 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type self.infer_variable_type(inferred, lvalue, self.accept(rvalue), rvalue) - def check_assignment_to_multiple_lvalues(self, lvalues: List[Lvalue], rvalue: Expression, - context: Context, - infer_lvalue_type: bool = True) -> None: - if isinstance(rvalue, TupleExpr) or isinstance(rvalue, ListExpr): - # Recursively go into Tuple or List expression rhs instead of - # using the type of rhs, because this allowed more fine grained - # control in cases like: a, b = [int, str] where rhs would get - # type List[object] - - rvalues = rvalue.items - - if self.check_rvalue_count_in_assignment(lvalues, len(rvalues), context): - star_index = next((i for i, lv in enumerate(lvalues) if - isinstance(lv, StarExpr)), len(lvalues)) - - left_lvs = lvalues[:star_index] - star_lv = cast(StarExpr, - lvalues[star_index]) if star_index != len(lvalues) else None - right_lvs = lvalues[star_index + 1:] - - left_rvs, star_rvs, right_rvs = self.split_around_star( - rvalues, star_index, len(lvalues)) - - lr_pairs = list(zip(left_lvs, left_rvs)) - if star_lv: - rv_list = ListExpr(star_rvs) - rv_list.set_line(rvalue.get_line()) - lr_pairs.append((star_lv.expr, rv_list)) - lr_pairs.extend(zip(right_lvs, right_rvs)) - - for lv, rv in lr_pairs: - self.check_assignment(lv, rv, infer_lvalue_type) - else: - self.check_multi_assignment(lvalues, rvalue, context, infer_lvalue_type) + def check_multi_assign_literal(self, lvalues: List[Lvalue], + rvalue: Union[ListExpr, TupleExpr], + context: Context, infer_lvalue_type: bool = True) -> None: + # Recursively go into Tuple or List expression rhs instead of + # using the type of rhs, because this allowed more fine grained + # control in cases like: a, b = [int, str] where rhs would get + # type List[object] + # Tuple is also special cased to handle mutually nested lists and tuples + rvalues = rvalue.items + if self.check_rvalue_count_in_assignment(lvalues, len(rvalues), context): + star_index = next((i for (i, lv) in enumerate(lvalues) if isinstance(lv, StarExpr)), + len(lvalues)) + left_lvs = lvalues[:star_index] + star_lv = cast(StarExpr, lvalues[star_index]) if star_index != len(lvalues) else None + right_lvs = lvalues[star_index + 1:] + left_rvs, star_rvs, right_rvs = self.split_around_star( + rvalues, star_index, len(lvalues)) + lr_pairs = list(zip(left_lvs, left_rvs)) + if star_lv: + rv_list = ListExpr(star_rvs) + rv_list.set_line(rvalue.get_line()) + lr_pairs.append((star_lv.expr, rv_list)) + lr_pairs.extend(zip(right_lvs, right_rvs)) + for lv, rv in lr_pairs: + self.check_assignment(lv, rv, infer_lvalue_type) def check_rvalue_count_in_assignment(self, lvalues: List[Lvalue], rvalue_count: int, context: Context) -> bool: @@ -1171,33 +1169,64 @@ def check_rvalue_count_in_assignment(self, lvalues: List[Lvalue], rvalue_count: return False return True - def check_multi_assignment(self, lvalues: List[Lvalue], - rvalue: Expression, - context: Context, - infer_lvalue_type: bool = True, - msg: str = None) -> None: - """Check the assignment of one rvalue to a number of lvalues.""" - - # Infer the type of an ordinary rvalue expression. - rvalue_type = self.accept(rvalue) # TODO maybe elsewhere; redundant - undefined_rvalue = False - + def check_multi_assign_from_any(self, lvalues: List[Expression], rvalue: Expression, + rvalue_type: AnyType, context: Context, + undefined_rvalue: bool, + infer_lvalue_type: bool) -> None: + for lv in lvalues: + if isinstance(lv, StarExpr): + lv = lv.expr + self.check_assignment(lv, self.temp_node(AnyType(), context), infer_lvalue_type) + + def check_multi_assign(self, lvalues: List[Lvalue], rvalue: Expression, + rvalue_type: Type, context: Context, *, + undefined_rvalue: bool = False, + infer_lvalue_type: bool = True) -> None: if isinstance(rvalue_type, AnyType): - for lv in lvalues: - if isinstance(lv, StarExpr): - lv = lv.expr - self.check_assignment(lv, self.temp_node(AnyType(), context), infer_lvalue_type) + self.check_multi_assign_from_any(lvalues, rvalue, rvalue_type, + context, undefined_rvalue, infer_lvalue_type) elif isinstance(rvalue_type, TupleType): - self.check_multi_assignment_from_tuple(lvalues, rvalue, rvalue_type, + self.check_multi_assign_from_tuple(lvalues, rvalue, rvalue_type, + context, undefined_rvalue, infer_lvalue_type) + elif isinstance(rvalue_type, UnionType): + self.check_multi_assign_from_union(lvalues, rvalue, rvalue_type, + context, undefined_rvalue, infer_lvalue_type) + elif isinstance(rvalue_type, Instance) and self.instance_is_iterable(rvalue_type): + self.check_multi_assign_from_iterable(lvalues, rvalue, rvalue_type, context, undefined_rvalue, infer_lvalue_type) else: - self.check_multi_assignment_from_iterable(lvalues, rvalue_type, - context, infer_lvalue_type) + self.msg.type_not_iterable(rvalue_type, context) + + def check_multi_assign_from_union(self, lvalues: List[Expression], rvalue: Expression, + rvalue_type: UnionType, context: Context, + undefined_rvalue: bool, + infer_lvalue_type: bool) -> None: + transposed = tuple([] for _ in lvalues) # type: Tuple[List[Type], ...] + with self.binder.accumulate_type_assignments() as assignments: + for item in rvalue_type.items: + self.check_multi_assign(lvalues, rvalue, item, context, + undefined_rvalue=True, + infer_lvalue_type=infer_lvalue_type) + for t, lv in zip(transposed, lvalues): + t.append(self.type_map.get(lv, AnyType())) + union_types = tuple(join_type_list(col) for col in transposed) + for expr, items in assignments.items(): + types, declared_types = zip(*items) + self.binder.assign_type(expr, + join_type_list(types), + join_type_list(declared_types), + False) + for union, lv in zip(union_types, lvalues): + _1, _2, inferred = self.check_lvalue(lv) + if inferred: + self.set_inferred_type(inferred, lv, union) + else: + self.store_type(lv, union) - def check_multi_assignment_from_tuple(self, lvalues: List[Lvalue], rvalue: Expression, - rvalue_type: TupleType, context: Context, - undefined_rvalue: bool, - infer_lvalue_type: bool = True) -> None: + def check_multi_assign_from_tuple(self, lvalues: List[Lvalue], rvalue: Expression, + rvalue_type: TupleType, context: Context, + undefined_rvalue: bool, + infer_lvalue_type: bool) -> None: if self.check_rvalue_count_in_assignment(lvalues, len(rvalue_type.items), context): star_index = next((i for i, lv in enumerate(lvalues) if isinstance(lv, StarExpr)), len(lvalues)) @@ -1275,25 +1304,22 @@ def split_around_star(self, items: List[T], star_index: int, right = items[right_index:] return (left, star, right) - def type_is_iterable(self, type: Type) -> bool: - return (is_subtype(type, self.named_generic_type('typing.Iterable', - [AnyType()])) and - isinstance(type, Instance)) - - def check_multi_assignment_from_iterable(self, lvalues: List[Lvalue], rvalue_type: Type, - context: Context, - infer_lvalue_type: bool = True) -> None: - if self.type_is_iterable(rvalue_type): - item_type = self.iterable_item_type(cast(Instance, rvalue_type)) - for lv in lvalues: - if isinstance(lv, StarExpr): - self.check_assignment(lv.expr, self.temp_node(rvalue_type, context), - infer_lvalue_type) - else: - self.check_assignment(lv, self.temp_node(item_type, context), - infer_lvalue_type) - else: - self.msg.type_not_iterable(rvalue_type, context) + def instance_is_iterable(self, instance: Instance) -> bool: + return is_subtype(instance, self.named_generic_type('typing.Iterable', + [AnyType()])) + + def check_multi_assign_from_iterable(self, lvalues: List[Expression], rvalue: Expression, + rvalue_type: Instance, context: Context, + undefined_rvalue: bool, + infer_lvalue_type: bool) -> None: + item_type = self.iterable_item_type(rvalue_type) + for lv in lvalues: + if isinstance(lv, StarExpr): + self.check_assignment(lv.expr, self.temp_node(rvalue_type, context), + infer_lvalue_type) + else: + self.check_assignment(lv, self.temp_node(item_type, context), + infer_lvalue_type) def check_lvalue(self, lvalue: Lvalue) -> Tuple[Type, IndexExpr, Var]: lvalue_type = None # type: Type diff --git a/test-data/unit/check-unions.test b/test-data/unit/check-unions.test index 121413836014..cc396d734de4 100644 --- a/test-data/unit/check-unions.test +++ b/test-data/unit/check-unions.test @@ -129,3 +129,123 @@ class C(Generic[T, U]): a = C() # type: C[int, int] b = a.f('a') a.f(b) # E: Argument 1 to "f" of "C" has incompatible type "int"; expected "str" + +[case testUnionMultiassign1] +from typing import Union, Tuple, Any + +b = None # type: Union[Tuple[int], Tuple[float]] +(b1,) = b +reveal_type(b1) # E: Revealed type is 'builtins.float' + +[case testUnionMultiassign2] +from typing import Union, Tuple + +c = None # type: Union[Tuple[int, int], Tuple[int, str]] +q = None # type: Tuple[int, float] +(c1, c2) = c +reveal_type(c1) # E: Revealed type is 'builtins.int' +reveal_type(c2) # E: Revealed type is 'builtins.object' + +[case testUnionMultiassignAny] +from typing import Union, Tuple, Any + +d = None # type: Union[Any, Tuple[float, float]] +(d1, d2) = d +reveal_type(d1) # E: Revealed type is 'Any' +reveal_type(d2) # E: Revealed type is 'Any' + +e = None # type: Union[Any, Tuple[float, float], int] +(e1, e2) = e # E: 'builtins.int' object is not iterable + +[case testUnionMultiassignJoin] +from typing import Union, List + +class A: pass +class B(A): pass +class C(A): pass +a = None # type: Union[List[B], List[C]] +x, y = a +reveal_type(x) # E: Revealed type is '__main__.A' + +[builtins fixtures/list.pyi] +[case testUnionMultiassignRebind] +from typing import Union, List + +class A: pass +class B(A): pass +class C(A): pass +obj = None # type: object +c = None # type: object +a = None # type: Union[List[B], List[C]] +obj, new = a +reveal_type(obj) # E: Revealed type is '__main__.A' +reveal_type(new) # E: Revealed type is '__main__.A' + +obj = 1 +reveal_type(obj) # E: Revealed type is 'builtins.int' + +[builtins fixtures/list.pyi] + +[case testUnionMultiassignAlreadyDeclared] +from typing import Union, Tuple + +a = None # type: Union[Tuple[int, int], Tuple[int, float]] +a1 = None # type: object +a2 = None # type: int + +(a1, a2) = a # E: Incompatible types in assignment (expression has type "float", variable has type "int") + +b = None # type: Union[Tuple[float, int], Tuple[int, int]] +b1 = None # type: object +b2 = None # type: int + +(b1, b2) = a # E: Incompatible types in assignment (expression has type "float", variable has type "int") + +c = None # type: Union[Tuple[int, int], Tuple[int, int]] +c1 = None # type: object +c2 = None # type: int + +(c1, c2) = c +reveal_type(c1) # E: Revealed type is 'builtins.int' +reveal_type(c2) # E: Revealed type is 'builtins.int' + +d = None # type: Union[Tuple[int, int], Tuple[int, float]] +d1 = None # type: object + +(d1, d2) = d +reveal_type(d1) # E: Revealed type is 'builtins.int' +reveal_type(d2) # E: Revealed type is 'builtins.float' + +[case testUnionMultiassignIndexed] +from typing import Union, Tuple, List + +class B: + x = None # type: object + +x = None # type: List[int] +b = None # type: B + +a = None # type: Union[Tuple[int, int], Tuple[int, object]] +(x[0], b.x) = a + +# I don't know why is it incomplete type +reveal_type(x[0]) # E: Revealed type is 'builtins.int*' +reveal_type(b.x) # E: Revealed type is 'builtins.object' + +[builtins fixtures/list.pyi] + +[case testUnionMultiassignPacked] +from typing import Union, Tuple, List + +a = None # type: Union[Tuple[int, int, int], Tuple[int, int, str]] +a1 = None # type: int +a2 = None # type: object +--FIX: allow proper rebinding of packed +xs = None # type: List[int] +(a1, *xs, a2) = a + +reveal_type(a1) # E: Revealed type is 'builtins.int' +reveal_type(xs) # E: Revealed type is 'builtins.list[builtins.int]' +reveal_type(a2) # E: Revealed type is 'builtins.int' + +[builtins fixtures/list.pyi] \ No newline at end of file diff --git a/test-data/unit/fixtures/list.pyi b/test-data/unit/fixtures/list.pyi index 2f9893d727bb..c12aa80f7fc3 100644 --- a/test-data/unit/fixtures/list.pyi +++ b/test-data/unit/fixtures/list.pyi @@ -19,6 +19,7 @@ class list(Iterable[T], Generic[T]): def __add__(self, x: list[T]) -> list[T]: pass def __mul__(self, x: int) -> list[T]: pass def __getitem__(self, x: int) -> T: pass + def __setitem__(self, x: int, v: T) -> None: pass def append(self, x: T) -> None: pass def extend(self, x: Iterable[T]) -> None: pass