diff --git a/mypy/checker.py b/mypy/checker.py index c012251dad9f..9adbff2923b7 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -55,7 +55,7 @@ ) from mypy import message_registry from mypy.subtypes import ( - is_subtype, is_equivalent, is_proper_subtype, is_more_precise, + covers_at_runtime, is_subtype, is_equivalent, is_proper_subtype, is_more_precise, restrict_subtype_away, is_subtype_ignoring_tvars, is_callable_compatible, unify_generic_callable, find_member ) @@ -3986,11 +3986,34 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM if len(node.args) != 2: # the error will be reported elsewhere return {}, {} if literal(expr) == LITERAL_TYPE: - return self.conditional_type_map_with_intersection( + if_map_base, else_map_base = self.conditional_type_map_with_intersection( expr, type_map[expr], get_isinstance_type(node.args[1], type_map), ) + first_node_arg = node.args[0] + if isinstance(first_node_arg, IndexExpr): + arg_type = get_proper_type(type_map[first_node_arg.base]) + if (isinstance(arg_type, UnionType) and + if_map_base is not None and else_map_base is not None): + if_branch_union = [] + else_branch_union = [] + t = if_map_base[expr] + for x in arg_type.items: + x = get_proper_type(x) + if not isinstance(x, Instance) or not x.args: + return if_map_base, else_map_base + if (is_overlapping_types(x.args[0], t) is not + covers_at_runtime(x.args[0], t, False)): + if_branch_union.append(x) + else_branch_union.append(x) + elif is_overlapping_types(x.args[0], t): + if_branch_union.append(x) + else: + else_branch_union.append(x) + if_map_base[first_node_arg.base] = UnionType(if_branch_union) + else_map_base[first_node_arg.base] = UnionType(else_branch_union) + return if_map_base, else_map_base elif refers_to_fullname(node.callee, 'builtins.issubclass'): if len(node.args) != 2: # the error will be reported elsewhere return {}, {} diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index a1d9685cc43d..3180e40c077f 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1026,3 +1026,79 @@ else: reveal_type(str_or_bool_literal) # N: Revealed type is 'Union[Literal[False], Literal[True]]' [builtins fixtures/primitives.pyi] + +[case testNarrowingUnionListTypes] +from typing import Union, List, Any, Sequence + +a: Union[List[List[int]], List[int]] +if isinstance(a[0], list): + reveal_type(a) # N: Revealed type is 'builtins.list[builtins.list[builtins.int]]' +else: + reveal_type(a) # N: Revealed type is 'builtins.list[builtins.int]' + +b: Union[List[str], List[int], List[bool]] +if isinstance(b[0], str): + reveal_type(b) # N: Revealed type is 'builtins.list[builtins.str]' +else: + reveal_type(b) # N: Revealed type is 'Union[builtins.list[builtins.int], builtins.list[builtins.bool]]' + +c: Union[List[Union[str, int]], List[int], List[str]] +if isinstance(c[0], str): + reveal_type(c) # N: Revealed type is 'Union[builtins.list[Union[builtins.str, builtins.int]], builtins.list[builtins.str]]' +else: + reveal_type(c) # N: Revealed type is 'Union[builtins.list[Union[builtins.str, builtins.int]], builtins.list[builtins.int]]' + +d: Union[List[str], List[int]] +if isinstance(d[0], str): + reveal_type(d) # N: Revealed type is 'builtins.list[builtins.str]' +else: + reveal_type(d) # N: Revealed type is 'builtins.list[builtins.int]' + +if isinstance(d[0], int): + reveal_type(d) # N: Revealed type is 'builtins.list[builtins.int]' +else: + reveal_type(d) # N: Revealed type is 'builtins.list[builtins.str]' + +e: Union[List[object], List[str]] +if isinstance(e[0], str): + reveal_type(e) # N: Revealed type is 'Union[builtins.list[builtins.object], builtins.list[builtins.str]]' +else: + reveal_type(e) # N: Revealed type is 'builtins.list[builtins.object]' + +f: Union[List[Any], List[str]] +if isinstance(f[0], str): + reveal_type(f) # N: Revealed type is 'Union[builtins.list[Any], builtins.list[builtins.str]]' +else: + reveal_type(f) # N: Revealed type is 'Union[builtins.list[Any], builtins.list[builtins.str]]' + +h: Union[List[int], Any] +if isinstance(h[0], int): + reveal_type(h[0]) # N: Revealed type is 'builtins.int*' + reveal_type(h) # N: Revealed type is 'Union[builtins.list[builtins.int], Any]' +else: + reveal_type(h[0]) # N: Revealed type is 'Any' + reveal_type(h) # N: Revealed type is 'Union[builtins.list[builtins.int], Any]' + +i: Union[List[int], Sequence[str]] +if isinstance(i[0], int): + reveal_type(i) # N: Revealed type is 'builtins.list[builtins.int]' +else: + reveal_type(i) # N: Revealed type is 'typing.Sequence[builtins.str]' + +q: Union[List[int], List[str], List[List[object]]] +if isinstance(q[0], (str, int)): + reveal_type(q[0]) # N: Revealed type is 'Union[builtins.int*, builtins.str*]' + reveal_type(q) # N: Revealed type is 'Union[builtins.list[builtins.int], builtins.list[builtins.str]]' +else: + reveal_type(q[0]) # N: Revealed type is 'builtins.list*[builtins.object]' + reveal_type(q) # N: Revealed type is 'builtins.list[builtins.list[builtins.object]]' + +g: Union[List[str], List[int]] +if isinstance(g[0], (int, str)): + reveal_type(g) # N: Revealed type is 'Union[builtins.list[builtins.str], builtins.list[builtins.int]]' +else: + reveal_type(g) + +[builtins fixtures/list.pyi] +[builtins fixtures/isinstancelist.pyi] +[typing fixtures/typing-full.pyi]