From cf425bddc94db15db3600cc113c566d2d3e22764 Mon Sep 17 00:00:00 2001 From: Steve Dignam Date: Sun, 30 Aug 2020 16:49:27 -0400 Subject: [PATCH 1/3] add narrowing for lists via index isinstance check Allows for narrowing a union of list types by indexing into the list and checking the type of an element. mypy already narrows the specific index expression. Since mypy already narrows the index expression's type, we use the type to compare against the type arguments of the lists and generate a new type that is narrower than the original. ``` from typing import Union, List a: Union[List[List[int]], List[int]] # before if isinstance(a[0], list): reveal_type(a) # N: Revealed type is 'Union[builtins.list[builtins.list[builtins.int]], builtins.list[builtins.int]]' else: reveal_type(a) # N: Revealed type is 'Union[builtins.list[builtins.list[builtins.int]], builtins.list[builtins.int]]' # after if isinstance(a[0], list): reveal_type(a) # N: Revealed type is 'builtins.list[builtins.list[builtins.int]]' else: reveal_type(a) # N: Revealed type is 'builtins.list[builtins.int]' ``` fixes: https://github.com/python/mypy/issues/9362 --- mypy/checker.py | 26 +++++++++- test-data/unit/check-narrowing.test | 76 +++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 2 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index c012251dad9f..cb42d61df62f 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -55,7 +55,7 @@ ) from mypy import message_registry from mypy.subtypes import ( - is_subtype, is_equivalent, is_proper_subtype, is_more_precise, + covers_at_runtime, is_subtype, is_equivalent, is_proper_subtype, is_more_precise, restrict_subtype_away, is_subtype_ignoring_tvars, is_callable_compatible, unify_generic_callable, find_member ) @@ -3986,11 +3986,33 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM if len(node.args) != 2: # the error will be reported elsewhere return {}, {} if literal(expr) == LITERAL_TYPE: - return self.conditional_type_map_with_intersection( + if_map, else_map = self.conditional_type_map_with_intersection( expr, type_map[expr], get_isinstance_type(node.args[1], type_map), ) + first_node_arg = node.args[0] + if isinstance(first_node_arg, IndexExpr): + arg_type = get_proper_type(type_map[first_node_arg.base]) + if isinstance(arg_type, UnionType) and if_map is not None and else_map is not None: + var_name = first_node_arg.base + if_branch_union = [] + else_branch_union = [] + t = if_map[expr] + for x in arg_type.items: + x = get_proper_type(x) + if not isinstance(x, Instance) or not x.args: + return if_map, else_map + if is_overlapping_types(x.args[0], t) is not covers_at_runtime(x.args[0], t, False): + if_branch_union.append(x) + else_branch_union.append(x) + elif is_overlapping_types(x.args[0], t): + if_branch_union.append(x) + else: + else_branch_union.append(x) + if_map[var_name] = UnionType(if_branch_union) + else_map[var_name] = UnionType(else_branch_union) + return if_map, else_map elif refers_to_fullname(node.callee, 'builtins.issubclass'): if len(node.args) != 2: # the error will be reported elsewhere return {}, {} diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index a1d9685cc43d..3180e40c077f 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1026,3 +1026,79 @@ else: reveal_type(str_or_bool_literal) # N: Revealed type is 'Union[Literal[False], Literal[True]]' [builtins fixtures/primitives.pyi] + +[case testNarrowingUnionListTypes] +from typing import Union, List, Any, Sequence + +a: Union[List[List[int]], List[int]] +if isinstance(a[0], list): + reveal_type(a) # N: Revealed type is 'builtins.list[builtins.list[builtins.int]]' +else: + reveal_type(a) # N: Revealed type is 'builtins.list[builtins.int]' + +b: Union[List[str], List[int], List[bool]] +if isinstance(b[0], str): + reveal_type(b) # N: Revealed type is 'builtins.list[builtins.str]' +else: + reveal_type(b) # N: Revealed type is 'Union[builtins.list[builtins.int], builtins.list[builtins.bool]]' + +c: Union[List[Union[str, int]], List[int], List[str]] +if isinstance(c[0], str): + reveal_type(c) # N: Revealed type is 'Union[builtins.list[Union[builtins.str, builtins.int]], builtins.list[builtins.str]]' +else: + reveal_type(c) # N: Revealed type is 'Union[builtins.list[Union[builtins.str, builtins.int]], builtins.list[builtins.int]]' + +d: Union[List[str], List[int]] +if isinstance(d[0], str): + reveal_type(d) # N: Revealed type is 'builtins.list[builtins.str]' +else: + reveal_type(d) # N: Revealed type is 'builtins.list[builtins.int]' + +if isinstance(d[0], int): + reveal_type(d) # N: Revealed type is 'builtins.list[builtins.int]' +else: + reveal_type(d) # N: Revealed type is 'builtins.list[builtins.str]' + +e: Union[List[object], List[str]] +if isinstance(e[0], str): + reveal_type(e) # N: Revealed type is 'Union[builtins.list[builtins.object], builtins.list[builtins.str]]' +else: + reveal_type(e) # N: Revealed type is 'builtins.list[builtins.object]' + +f: Union[List[Any], List[str]] +if isinstance(f[0], str): + reveal_type(f) # N: Revealed type is 'Union[builtins.list[Any], builtins.list[builtins.str]]' +else: + reveal_type(f) # N: Revealed type is 'Union[builtins.list[Any], builtins.list[builtins.str]]' + +h: Union[List[int], Any] +if isinstance(h[0], int): + reveal_type(h[0]) # N: Revealed type is 'builtins.int*' + reveal_type(h) # N: Revealed type is 'Union[builtins.list[builtins.int], Any]' +else: + reveal_type(h[0]) # N: Revealed type is 'Any' + reveal_type(h) # N: Revealed type is 'Union[builtins.list[builtins.int], Any]' + +i: Union[List[int], Sequence[str]] +if isinstance(i[0], int): + reveal_type(i) # N: Revealed type is 'builtins.list[builtins.int]' +else: + reveal_type(i) # N: Revealed type is 'typing.Sequence[builtins.str]' + +q: Union[List[int], List[str], List[List[object]]] +if isinstance(q[0], (str, int)): + reveal_type(q[0]) # N: Revealed type is 'Union[builtins.int*, builtins.str*]' + reveal_type(q) # N: Revealed type is 'Union[builtins.list[builtins.int], builtins.list[builtins.str]]' +else: + reveal_type(q[0]) # N: Revealed type is 'builtins.list*[builtins.object]' + reveal_type(q) # N: Revealed type is 'builtins.list[builtins.list[builtins.object]]' + +g: Union[List[str], List[int]] +if isinstance(g[0], (int, str)): + reveal_type(g) # N: Revealed type is 'Union[builtins.list[builtins.str], builtins.list[builtins.int]]' +else: + reveal_type(g) + +[builtins fixtures/list.pyi] +[builtins fixtures/isinstancelist.pyi] +[typing fixtures/typing-full.pyi] From a450f7a83f34378283367ba6dba52e22671fc168 Mon Sep 17 00:00:00 2001 From: Steve Dignam Date: Sun, 30 Aug 2020 16:52:34 -0400 Subject: [PATCH 2/3] cut var --- mypy/checker.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index cb42d61df62f..63066630982c 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3995,7 +3995,6 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM if isinstance(first_node_arg, IndexExpr): arg_type = get_proper_type(type_map[first_node_arg.base]) if isinstance(arg_type, UnionType) and if_map is not None and else_map is not None: - var_name = first_node_arg.base if_branch_union = [] else_branch_union = [] t = if_map[expr] @@ -4010,8 +4009,8 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM if_branch_union.append(x) else: else_branch_union.append(x) - if_map[var_name] = UnionType(if_branch_union) - else_map[var_name] = UnionType(else_branch_union) + if_map[first_node_arg.base] = UnionType(if_branch_union) + else_map[first_node_arg.base] = UnionType(else_branch_union) return if_map, else_map elif refers_to_fullname(node.callee, 'builtins.issubclass'): if len(node.args) != 2: # the error will be reported elsewhere From fd85727b9541b21c475d520d2ab8e15ae5ed5f1c Mon Sep 17 00:00:00 2001 From: Steve Dignam Date: Sun, 30 Aug 2020 17:04:32 -0400 Subject: [PATCH 3/3] fix flake8 and mypy --- mypy/checker.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 63066630982c..9adbff2923b7 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3986,7 +3986,7 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM if len(node.args) != 2: # the error will be reported elsewhere return {}, {} if literal(expr) == LITERAL_TYPE: - if_map, else_map = self.conditional_type_map_with_intersection( + if_map_base, else_map_base = self.conditional_type_map_with_intersection( expr, type_map[expr], get_isinstance_type(node.args[1], type_map), @@ -3994,24 +3994,26 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM first_node_arg = node.args[0] if isinstance(first_node_arg, IndexExpr): arg_type = get_proper_type(type_map[first_node_arg.base]) - if isinstance(arg_type, UnionType) and if_map is not None and else_map is not None: + if (isinstance(arg_type, UnionType) and + if_map_base is not None and else_map_base is not None): if_branch_union = [] else_branch_union = [] - t = if_map[expr] + t = if_map_base[expr] for x in arg_type.items: x = get_proper_type(x) if not isinstance(x, Instance) or not x.args: - return if_map, else_map - if is_overlapping_types(x.args[0], t) is not covers_at_runtime(x.args[0], t, False): + return if_map_base, else_map_base + if (is_overlapping_types(x.args[0], t) is not + covers_at_runtime(x.args[0], t, False)): if_branch_union.append(x) else_branch_union.append(x) elif is_overlapping_types(x.args[0], t): if_branch_union.append(x) else: else_branch_union.append(x) - if_map[first_node_arg.base] = UnionType(if_branch_union) - else_map[first_node_arg.base] = UnionType(else_branch_union) - return if_map, else_map + if_map_base[first_node_arg.base] = UnionType(if_branch_union) + else_map_base[first_node_arg.base] = UnionType(else_branch_union) + return if_map_base, else_map_base elif refers_to_fullname(node.callee, 'builtins.issubclass'): if len(node.args) != 2: # the error will be reported elsewhere return {}, {}