Skip to content

Commit cdfc272

Browse files
authored
Fine-grained: Support async def/for/with and await (#4969)
Some of the changes motivated cleaning up some duplicate error messages. Also noticed missing __next__/next dependency in ordinary for statements and fixed it as well. Work towards #4951.
1 parent 498a9cc commit cdfc272

12 files changed

+245
-65
lines changed

mypy/checker.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2517,57 +2517,47 @@ def get_types_from_except_handler(self, typ: Type, n: Expression) -> List[Type]:
25172517
def visit_for_stmt(self, s: ForStmt) -> None:
25182518
"""Type check a for statement."""
25192519
if s.is_async:
2520-
item_type = self.analyze_async_iterable_item_type(s.expr)
2520+
iterator_type, item_type = self.analyze_async_iterable_item_type(s.expr)
25212521
else:
2522-
item_type = self.analyze_iterable_item_type(s.expr)
2522+
iterator_type, item_type = self.analyze_iterable_item_type(s.expr)
25232523
s.inferred_item_type = item_type
2524+
s.inferred_iterator_type = iterator_type
25242525
self.analyze_index_variables(s.index, item_type, s.index_type is None, s)
25252526
self.accept_loop(s.body, s.else_body)
25262527

2527-
def analyze_async_iterable_item_type(self, expr: Expression) -> Type:
2528-
"""Analyse async iterable expression and return iterator item type."""
2528+
def analyze_async_iterable_item_type(self, expr: Expression) -> Tuple[Type, Type]:
2529+
"""Analyse async iterable expression and return iterator and iterator item types."""
25292530
echk = self.expr_checker
25302531
iterable = echk.accept(expr)
2531-
2532-
self.check_subtype(iterable,
2533-
self.named_generic_type('typing.AsyncIterable',
2534-
[AnyType(TypeOfAny.special_form)]),
2535-
expr, messages.ASYNC_ITERABLE_EXPECTED)
2536-
25372532
method = echk.analyze_external_member_access('__aiter__', iterable, expr)
25382533
iterator = echk.check_call(method, [], [], expr)[0]
25392534
method = echk.analyze_external_member_access('__anext__', iterator, expr)
25402535
awaitable = echk.check_call(method, [], [], expr)[0]
2541-
return echk.check_awaitable_expr(awaitable, expr,
2542-
messages.INCOMPATIBLE_TYPES_IN_ASYNC_FOR)
2536+
item_type = echk.check_awaitable_expr(awaitable, expr,
2537+
messages.INCOMPATIBLE_TYPES_IN_ASYNC_FOR)
2538+
return iterator, item_type
25432539

2544-
def analyze_iterable_item_type(self, expr: Expression) -> Type:
2545-
"""Analyse iterable expression and return iterator item type."""
2540+
def analyze_iterable_item_type(self, expr: Expression) -> Tuple[Type, Type]:
2541+
"""Analyse iterable expression and return iterator and iterator item types."""
25462542
echk = self.expr_checker
25472543
iterable = echk.accept(expr)
2544+
method = echk.analyze_external_member_access('__iter__', iterable, expr)
2545+
iterator = echk.check_call(method, [], [], expr)[0]
25482546

25492547
if isinstance(iterable, TupleType):
25502548
joined = UninhabitedType() # type: Type
25512549
for item in iterable.items:
25522550
joined = join_types(joined, item)
2553-
return joined
2551+
return iterator, joined
25542552
else:
25552553
# Non-tuple iterable.
2556-
self.check_subtype(iterable,
2557-
self.named_generic_type('typing.Iterable',
2558-
[AnyType(TypeOfAny.special_form)]),
2559-
expr, messages.ITERABLE_EXPECTED)
2560-
2561-
method = echk.analyze_external_member_access('__iter__', iterable,
2562-
expr)
2563-
iterator = echk.check_call(method, [], [], expr)[0]
25642554
if self.options.python_version[0] >= 3:
25652555
nextmethod = '__next__'
25662556
else:
25672557
nextmethod = 'next'
25682558
method = echk.analyze_external_member_access(nextmethod, iterator,
25692559
expr)
2570-
return echk.check_call(method, [], [], expr)[0]
2560+
return iterator, echk.check_call(method, [], [], expr)[0]
25712561

25722562
def analyze_index_variables(self, index: Expression, item_type: Type,
25732563
infer_lvalue_type: bool, context: Context) -> None:

mypy/checkexpr.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1422,7 +1422,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
14221422
elif (local_errors.is_errors() and
14231423
# is_valid_var_arg is True for any Iterable
14241424
self.is_valid_var_arg(right_type)):
1425-
itertype = self.chk.analyze_iterable_item_type(right)
1425+
_, itertype = self.chk.analyze_iterable_item_type(right)
14261426
method_type = CallableType(
14271427
[left_type],
14281428
[nodes.ARG_POS],
@@ -2290,9 +2290,9 @@ def check_for_comp(self, e: Union[GeneratorExpr, DictionaryComprehension]) -> No
22902290
for index, sequence, conditions, is_async in zip(e.indices, e.sequences,
22912291
e.condlists, e.is_async):
22922292
if is_async:
2293-
sequence_type = self.chk.analyze_async_iterable_item_type(sequence)
2293+
_, sequence_type = self.chk.analyze_async_iterable_item_type(sequence)
22942294
else:
2295-
sequence_type = self.chk.analyze_iterable_item_type(sequence)
2295+
_, sequence_type = self.chk.analyze_iterable_item_type(sequence)
22962296
self.chk.analyze_index_variables(index, sequence_type, True, e)
22972297
for condition in conditions:
22982298
self.accept(condition)

mypy/messages.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@
6363
MUST_HAVE_NONE_RETURN_TYPE = 'The return type of "{}" must be None'
6464
INVALID_TUPLE_INDEX_TYPE = 'Invalid tuple index type'
6565
TUPLE_INDEX_OUT_OF_RANGE = 'Tuple index out of range'
66-
ITERABLE_EXPECTED = 'Iterable expected'
67-
ASYNC_ITERABLE_EXPECTED = 'AsyncIterable expected'
6866
INVALID_SLICE_INDEX = 'Slice index must be an integer or None'
6967
CANNOT_INFER_LAMBDA_TYPE = 'Cannot infer type of lambda'
7068
CANNOT_INFER_ITEM_TYPE = 'Cannot infer iterable item type'
@@ -450,28 +448,36 @@ def has_no_attr(self, original_type: Type, typ: Type, member: str, context: Cont
450448
self.fail('{} not callable'.format(self.format(original_type)), context)
451449
else:
452450
# The non-special case: a missing ordinary attribute.
451+
extra = ''
452+
if member == '__iter__':
453+
extra = ' (not iterable)'
454+
elif member == '__aiter__':
455+
extra = ' (not async iterable)'
453456
if not self.disable_type_names:
454457
failed = False
455458
if isinstance(original_type, Instance) and original_type.type.names:
456459
alternatives = set(original_type.type.names.keys())
457460
matches = [m for m in COMMON_MISTAKES.get(member, []) if m in alternatives]
458461
matches.extend(best_matches(member, alternatives)[:3])
462+
if member == '__aiter__' and matches == ['__iter__']:
463+
matches = [] # Avoid misleading suggestion
459464
if matches:
460-
self.fail('{} has no attribute "{}"; maybe {}?'.format(
461-
self.format(original_type), member, pretty_or(matches)), context)
465+
self.fail('{} has no attribute "{}"; maybe {}?{}'.format(
466+
self.format(original_type), member, pretty_or(matches), extra),
467+
context)
462468
failed = True
463469
if not failed:
464-
self.fail('{} has no attribute "{}"'.format(self.format(original_type),
465-
member), context)
470+
self.fail('{} has no attribute "{}"{}'.format(self.format(original_type),
471+
member, extra), context)
466472
elif isinstance(original_type, UnionType):
467473
# The checker passes "object" in lieu of "None" for attribute
468474
# checks, so we manually convert it back.
469475
typ_format = self.format(typ)
470476
if typ_format == '"object"' and \
471477
any(type(item) == NoneTyp for item in original_type.items):
472478
typ_format = '"None"'
473-
self.fail('Item {} of {} has no attribute "{}"'.format(
474-
typ_format, self.format(original_type), member), context)
479+
self.fail('Item {} of {} has no attribute "{}"{}'.format(
480+
typ_format, self.format(original_type), member, extra), context)
475481
return AnyType(TypeOfAny.from_error)
476482

477483
def unsupported_operand_types(self, op: str, left_type: Any,

mypy/nodes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -952,6 +952,8 @@ class ForStmt(Statement):
952952
index_type = None # type: Optional[mypy.types.Type]
953953
# Inferred iterable item type
954954
inferred_item_type = None # type: Optional[mypy.types.Type]
955+
# Inferred iterator type
956+
inferred_iterator_type = None # type: Optional[mypy.types.Type]
955957
# Expression to iterate
956958
expr = None # type: Expression
957959
body = None # type: Block

mypy/server/aststrip.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
Node, FuncDef, NameExpr, MemberExpr, RefExpr, MypyFile, FuncItem, ClassDef, AssignmentStmt,
4545
ImportFrom, Import, TypeInfo, SymbolTable, Var, CallExpr, Decorator, OverloadedFuncDef,
4646
SuperExpr, UNBOUND_IMPORTED, GDEF, MDEF, IndexExpr, SymbolTableNode, ImportAll, TupleExpr,
47-
ListExpr
47+
ListExpr, ForStmt
4848
)
4949
from mypy.semanal_shared import create_indirect_imported_name
5050
from mypy.traverser import TraverserVisitor
@@ -217,6 +217,11 @@ def visit_import(self, node: Import) -> None:
217217
symnode.kind = UNBOUND_IMPORTED
218218
symnode.node = None
219219

220+
def visit_for_stmt(self, node: ForStmt) -> None:
221+
node.index_type = None
222+
node.inferred_item_type = None
223+
super().visit_for_stmt(node)
224+
220225
def visit_import_all(self, node: ImportAll) -> None:
221226
# If the node is unreachable, we don't want to reset entries from a reachable import.
222227
if node.is_unreachable:

mypy/server/deps.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a
9292
ComparisonExpr, GeneratorExpr, DictionaryComprehension, StarExpr, PrintStmt, ForStmt, WithStmt,
9393
TupleExpr, ListExpr, OperatorAssignmentStmt, DelStmt, YieldFromExpr, Decorator, Block,
9494
TypeInfo, FuncBase, OverloadedFuncDef, RefExpr, SuperExpr, Var, NamedTupleExpr, TypedDictExpr,
95-
LDEF, MDEF, GDEF, FuncItem, TypeAliasExpr, NewTypeExpr, ImportAll, EnumCallExpr,
95+
LDEF, MDEF, GDEF, FuncItem, TypeAliasExpr, NewTypeExpr, ImportAll, EnumCallExpr, AwaitExpr,
9696
op_methods, reverse_op_methods, ops_with_inplace_method, unary_op_methods
9797
)
9898
from mypy.traverser import TraverserVisitor
@@ -166,7 +166,6 @@ def __init__(self,
166166
self.is_package_init_file = False
167167

168168
# TODO (incomplete):
169-
# await
170169
# protocols
171170

172171
def visit_mypy_file(self, o: MypyFile) -> None:
@@ -415,10 +414,22 @@ def visit_operator_assignment_stmt(self, o: OperatorAssignmentStmt) -> None:
415414

416415
def visit_for_stmt(self, o: ForStmt) -> None:
417416
super().visit_for_stmt(o)
418-
# __getitem__ is only used if __iter__ is missing but for simplicity we
419-
# just always depend on both.
420-
self.add_attribute_dependency_for_expr(o.expr, '__iter__')
421-
self.add_attribute_dependency_for_expr(o.expr, '__getitem__')
417+
if not o.is_async:
418+
# __getitem__ is only used if __iter__ is missing but for simplicity we
419+
# just always depend on both.
420+
self.add_attribute_dependency_for_expr(o.expr, '__iter__')
421+
self.add_attribute_dependency_for_expr(o.expr, '__getitem__')
422+
if o.inferred_iterator_type:
423+
if self.python2:
424+
method = 'next'
425+
else:
426+
method = '__next__'
427+
self.add_attribute_dependency(o.inferred_iterator_type, method)
428+
else:
429+
self.add_attribute_dependency_for_expr(o.expr, '__aiter__')
430+
if o.inferred_iterator_type:
431+
self.add_attribute_dependency(o.inferred_iterator_type, '__anext__')
432+
422433
self.process_lvalue(o.index)
423434
if isinstance(o.index, TupleExpr):
424435
# Process multiple assignment to index variables.
@@ -433,8 +444,12 @@ def visit_for_stmt(self, o: ForStmt) -> None:
433444
def visit_with_stmt(self, o: WithStmt) -> None:
434445
super().visit_with_stmt(o)
435446
for e in o.expr:
436-
self.add_attribute_dependency_for_expr(e, '__enter__')
437-
self.add_attribute_dependency_for_expr(e, '__exit__')
447+
if not o.is_async:
448+
self.add_attribute_dependency_for_expr(e, '__enter__')
449+
self.add_attribute_dependency_for_expr(e, '__exit__')
450+
else:
451+
self.add_attribute_dependency_for_expr(e, '__aenter__')
452+
self.add_attribute_dependency_for_expr(e, '__aexit__')
438453
if o.target_type:
439454
self.add_type_dependencies(o.target_type)
440455

@@ -584,6 +599,10 @@ def visit_yield_from_expr(self, e: YieldFromExpr) -> None:
584599
super().visit_yield_from_expr(e)
585600
self.add_iter_dependency(e.expr)
586601

602+
def visit_await_expr(self, e: AwaitExpr) -> None:
603+
super().visit_await_expr(e)
604+
self.add_attribute_dependency_for_expr(e.expr, '__await__')
605+
587606
# Helpers
588607

589608
def add_type_alias_deps(self, target: str) -> None:

test-data/unit/check-async-await.test

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,7 @@ async def f() -> None:
163163
[builtins fixtures/async_await.pyi]
164164
[typing fixtures/typing-full.pyi]
165165
[out]
166-
main:4: error: AsyncIterable expected
167-
main:4: error: "List[int]" has no attribute "__aiter__"
166+
main:4: error: "List[int]" has no attribute "__aiter__" (not async iterable)
168167

169168
[case testAsyncForTypeComments]
170169

@@ -247,14 +246,10 @@ async def wrong_iterable(obj: Iterable[int]):
247246
{i: i for i in asyncify(obj)}
248247

249248
[out]
250-
main:18: error: AsyncIterable expected
251-
main:18: error: "Iterable[int]" has no attribute "__aiter__"; maybe "__iter__"?
252-
main:19: error: Iterable expected
253-
main:19: error: "asyncify[int]" has no attribute "__iter__"; maybe "__aiter__"?
254-
main:20: error: AsyncIterable expected
255-
main:20: error: "Iterable[int]" has no attribute "__aiter__"; maybe "__iter__"?
256-
main:21: error: Iterable expected
257-
main:21: error: "asyncify[int]" has no attribute "__iter__"; maybe "__aiter__"?
249+
main:18: error: "Iterable[int]" has no attribute "__aiter__" (not async iterable)
250+
main:19: error: "asyncify[int]" has no attribute "__iter__"; maybe "__aiter__"? (not iterable)
251+
main:20: error: "Iterable[int]" has no attribute "__aiter__" (not async iterable)
252+
main:21: error: "asyncify[int]" has no attribute "__iter__"; maybe "__aiter__"? (not iterable)
258253
[builtins fixtures/async_await.pyi]
259254
[typing fixtures/typing-full.pyi]
260255

@@ -538,7 +533,7 @@ async def h() -> None:
538533
from typing import AsyncGenerator
539534

540535
async def gen() -> AsyncGenerator[int, None]:
541-
for i in (1, 2, 3):
536+
for i in [1, 2, 3]:
542537
yield i
543538

544539
def h() -> None:
@@ -549,8 +544,7 @@ def h() -> None:
549544
[typing fixtures/typing-full.pyi]
550545

551546
[out]
552-
main:9: error: Iterable expected
553-
main:9: error: "AsyncGenerator[int, None]" has no attribute "__iter__"; maybe "__aiter__"?
547+
main:9: error: "AsyncGenerator[int, None]" has no attribute "__iter__"; maybe "__aiter__"? (not iterable)
554548

555549
[case testAsyncGeneratorNoYieldFrom]
556550
# flags: --fast-parser --python-version 3.6

test-data/unit/check-incomplete-fixture.test

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,7 @@ for y in x:
7676
-- avoid things getting worse.
7777
main:2: error: "tuple" expects no type arguments, but 1 given
7878
main:3: error: Value of type "tuple" is not indexable
79-
main:4: error: Iterable expected
80-
main:4: error: "tuple" has no attribute "__iter__"
79+
main:4: error: "tuple" has no attribute "__iter__" (not iterable)
8180

8281
[case testClassmethodMissingFromStubs]
8382
class A:

test-data/unit/check-inference.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1563,7 +1563,7 @@ class A:
15631563
pass
15641564
[builtins fixtures/for.pyi]
15651565
[out]
1566-
main:5: error: "None" has no attribute "__iter__"
1566+
main:5: error: "None" has no attribute "__iter__" (not iterable)
15671567

15681568
[case testPartialTypeErrorSpecialCase2]
15691569
# This used to crash.
@@ -1584,7 +1584,7 @@ class A:
15841584
pass
15851585
[builtins fixtures/for.pyi]
15861586
[out]
1587-
main:4: error: "None" has no attribute "__iter__"
1587+
main:4: error: "None" has no attribute "__iter__" (not iterable)
15881588

15891589

15901590
-- Multipass

test-data/unit/check-unions.test

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -869,8 +869,7 @@ a: Any
869869
d: Dict[str, Tuple[List[Tuple[str, str]], str]]
870870
x, _ = d.get(a, (None, None))
871871

872-
for y in x: pass # E: Iterable expected \
873-
# E: Item "None" of "Optional[List[Tuple[str, str]]]" has no attribute "__iter__"
872+
for y in x: pass # E: Item "None" of "Optional[List[Tuple[str, str]]]" has no attribute "__iter__" (not iterable)
874873
if x:
875874
for s, t in x:
876875
reveal_type(s) # E: Revealed type is 'builtins.str'
@@ -886,8 +885,7 @@ x = None
886885
d: Dict[str, Tuple[List[Tuple[str, str]], str]]
887886
x, _ = d.get(a, (None, None))
888887

889-
for y in x: pass # E: Iterable expected \
890-
# E: Item "None" of "Optional[List[Tuple[str, str]]]" has no attribute "__iter__"
888+
for y in x: pass # E: Item "None" of "Optional[List[Tuple[str, str]]]" has no attribute "__iter__" (not iterable)
891889
if x:
892890
for s, t in x:
893891
reveal_type(s) # E: Revealed type is 'builtins.str'

test-data/unit/deps.test

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,49 @@ class B(Base):
10251025
<mod.Base> -> <m.f>, m.f
10261026
<mod> -> m
10271027

1028+
[case testCustomIterator]
1029+
class A:
1030+
def __iter__(self) -> B: pass
1031+
class B:
1032+
def __iter__(self) -> B: pass
1033+
def __next__(self) -> C: pass
1034+
class C:
1035+
pass
1036+
def f() -> None:
1037+
for x in A(): pass
1038+
[out]
1039+
<m.A.__getitem__> -> m.f
1040+
<m.A.__init__> -> m.f
1041+
<m.A.__iter__> -> m.f
1042+
<m.A.__new__> -> m.f
1043+
<m.A> -> m.A, m.f
1044+
<m.B.__next__> -> m.f
1045+
<m.B> -> <m.A.__iter__>, <m.B.__iter__>, m.A.__iter__, m.B, m.B.__iter__
1046+
<m.C> -> <m.B.__next__>, m.B.__next__, m.C
1047+
1048+
[case testCustomIterator_python2]
1049+
class A:
1050+
def __iter__(self): # type: () -> B
1051+
pass
1052+
class B:
1053+
def __iter__(self): # type: () -> B
1054+
pass
1055+
def next(self): # type: () -> C
1056+
pass
1057+
class C:
1058+
pass
1059+
def f(): # type: () -> None
1060+
for x in A(): pass
1061+
[out]
1062+
<m.A.__getitem__> -> m.f
1063+
<m.A.__init__> -> m.f
1064+
<m.A.__iter__> -> m.f
1065+
<m.A.__new__> -> m.f
1066+
<m.A> -> m.A, m.f
1067+
<m.B.next> -> m.f
1068+
<m.B> -> <m.A.__iter__>, <m.B.__iter__>, m.A.__iter__, m.B, m.B.__iter__
1069+
<m.C> -> <m.B.next>, m.B.next, m.C
1070+
10281071
[case testDepsLiskovClass]
10291072
from mod import A, C
10301073
class D(C):

0 commit comments

Comments
 (0)