diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index f412d95283a6..2047c8308bb0 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1531,55 +1531,68 @@ def check_generator_or_comprehension(self, gen: GeneratorExpr, type_name: str, id_for_messages: str) -> Type: """Type check a generator expression or a list comprehension.""" - self.check_for_comp(gen) + with self.chk.binder.frame_context(): + self.check_for_comp(gen) - # Infer the type of the list comprehension by using a synthetic generic - # callable type. - tvdef = TypeVarDef('T', -1, [], self.chk.object_type()) - tv = TypeVarType(tvdef) - constructor = CallableType( - [tv], - [nodes.ARG_POS], - [None], - self.chk.named_generic_type(type_name, [tv]), - self.chk.named_type('builtins.function'), - name=id_for_messages, - variables=[tvdef]) - return self.check_call(constructor, - [gen.left_expr], [nodes.ARG_POS], gen)[0] + # Infer the type of the list comprehension by using a synthetic generic + # callable type. + tvdef = TypeVarDef('T', -1, [], self.chk.object_type()) + tv = TypeVarType(tvdef) + constructor = CallableType( + [tv], + [nodes.ARG_POS], + [None], + self.chk.named_generic_type(type_name, [tv]), + self.chk.named_type('builtins.function'), + name=id_for_messages, + variables=[tvdef]) + return self.check_call(constructor, + [gen.left_expr], [nodes.ARG_POS], gen)[0] def visit_dictionary_comprehension(self, e: DictionaryComprehension) -> Type: """Type check a dictionary comprehension.""" - self.check_for_comp(e) - - # Infer the type of the list comprehension by using a synthetic generic - # callable type. - ktdef = TypeVarDef('KT', -1, [], self.chk.object_type()) - vtdef = TypeVarDef('VT', -2, [], self.chk.object_type()) - kt = TypeVarType(ktdef) - vt = TypeVarType(vtdef) - constructor = CallableType( - [kt, vt], - [nodes.ARG_POS, nodes.ARG_POS], - [None, None], - self.chk.named_generic_type('builtins.dict', [kt, vt]), - self.chk.named_type('builtins.function'), - name='', - variables=[ktdef, vtdef]) - return self.check_call(constructor, - [e.key, e.value], [nodes.ARG_POS, nodes.ARG_POS], e)[0] + with self.chk.binder.frame_context(): + self.check_for_comp(e) + + # Infer the type of the list comprehension by using a synthetic generic + # callable type. + ktdef = TypeVarDef('KT', -1, [], self.chk.object_type()) + vtdef = TypeVarDef('VT', -2, [], self.chk.object_type()) + kt = TypeVarType(ktdef) + vt = TypeVarType(vtdef) + constructor = CallableType( + [kt, vt], + [nodes.ARG_POS, nodes.ARG_POS], + [None, None], + self.chk.named_generic_type('builtins.dict', [kt, vt]), + self.chk.named_type('builtins.function'), + name='', + variables=[ktdef, vtdef]) + return self.check_call(constructor, + [e.key, e.value], [nodes.ARG_POS, nodes.ARG_POS], e)[0] def check_for_comp(self, e: Union[GeneratorExpr, DictionaryComprehension]) -> None: """Check the for_comp part of comprehensions. That is the part from 'for': ... for x in y if z + + Note: This adds the type information derived from the condlists to the current binder. """ - with self.chk.binder.frame_context(): - for index, sequence, conditions in zip(e.indices, e.sequences, - e.condlists): - sequence_type = self.chk.analyze_iterable_item_type(sequence) - self.chk.analyze_index_variables(index, sequence_type, e) - for condition in conditions: - self.accept(condition) + for index, sequence, conditions in zip(e.indices, e.sequences, + e.condlists): + sequence_type = self.chk.analyze_iterable_item_type(sequence) + self.chk.analyze_index_variables(index, sequence_type, e) + for condition in conditions: + self.accept(condition) + + # values are only part of the comprehension when all conditions are true + true_map, _ = mypy.checker.find_isinstance_check( + condition, self.chk.type_map, + self.chk.typing_mode_weak() + ) + + if true_map: + for var, type in true_map.items(): + self.chk.binder.push(var, type) def visit_conditional_expr(self, e: ConditionalExpr) -> Type: cond_type = self.accept(e.cond) diff --git a/mypy/nodes.py b/mypy/nodes.py index bc25b8441045..07c9282a09c8 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1504,7 +1504,7 @@ class GeneratorExpr(Expression): """Generator expression ... for ... in ... [ for ... in ... ] [ if ... ].""" left_expr = None # type: Expression - sequences_expr = None # type: List[Expression] + sequences = None # type: List[Expression] condlists = None # type: List[List[Expression]] indices = None # type: List[Expression] @@ -1548,7 +1548,7 @@ class DictionaryComprehension(Expression): key = None # type: Expression value = None # type: Expression - sequences_expr = None # type: List[Expression] + sequences = None # type: List[Expression] condlists = None # type: List[List[Expression]] indices = None # type: List[Expression] diff --git a/test-data/unit/check-isinstance.test b/test-data/unit/check-isinstance.test index 26736b2f92a1..0bb364b326c7 100644 --- a/test-data/unit/check-isinstance.test +++ b/test-data/unit/check-isinstance.test @@ -1155,3 +1155,13 @@ else: 1() [builtins fixtures/isinstance.py] [out] +[case testComprehensionIsInstance] +from typing import List, Union +a = [] # type: List[Union[int, str]] +l = [x for x in a if isinstance(x, int)] +g = (x for x in a if isinstance(x, int)) +d = {0: x for x in a if isinstance(x, int)} +reveal_type(l) # E: Revealed type is 'builtins.list[builtins.int*]' +reveal_type(g) # E: Revealed type is 'typing.Iterator[builtins.int*]' +reveal_type(d) # E: Revealed type is 'builtins.dict[builtins.int*, builtins.int*]' +[builtins fixtures/isinstancelist.py] diff --git a/test-data/unit/fixtures/isinstancelist.py b/test-data/unit/fixtures/isinstancelist.py index 99b0c209b89c..4b3569875ab9 100644 --- a/test-data/unit/fixtures/isinstancelist.py +++ b/test-data/unit/fixtures/isinstancelist.py @@ -1,4 +1,4 @@ -from typing import builtinclass, Iterable, Iterator, Generic, TypeVar, List +from typing import builtinclass, Iterable, Iterator, Generic, TypeVar, List, Mapping, overload, Tuple @builtinclass class object: @@ -24,6 +24,8 @@ def __add__(self, x: str) -> str: pass def __getitem__(self, x: int) -> str: pass T = TypeVar('T') +KT = TypeVar('KT') +VT = TypeVar('VT') class list(Iterable[T], Generic[T]): def __iter__(self) -> Iterator[T]: pass @@ -31,3 +33,12 @@ def __mul__(self, x: int) -> list[T]: pass def __setitem__(self, x: int, v: T) -> None: pass def __getitem__(self, x: int) -> T: pass def __add__(self, x: List[T]) -> T: pass + +class dict(Iterable[KT], Mapping[KT, VT], Generic[KT, VT]): + @overload + def __init__(self, **kwargs: VT) -> None: pass + @overload + def __init__(self, arg: Iterable[Tuple[KT, VT]], **kwargs: VT) -> None: pass + def __setitem__(self, k: KT, v: VT) -> None: pass + def __iter__(self) -> Iterator[KT]: pass + def update(self, a: Mapping[KT, VT]) -> None: pass