From 381ec52376185af735baa0feb33c1b65cd3dca0e Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 3 Jun 2016 13:52:24 -0700 Subject: [PATCH 01/28] PEP 492 syntax: `async def` and `await`. --- mypy/checker.py | 96 +++++++++++++++++++++++--- mypy/fastparse.py | 43 +++++++++--- mypy/messages.py | 1 + mypy/nodes.py | 41 +++++++++++ mypy/parse.py | 4 ++ mypy/semanal.py | 21 +++++- mypy/strconv.py | 3 + mypy/test/testcheck.py | 1 + mypy/treetransform.py | 6 +- mypy/visitor.py | 3 + test-data/unit/check-async-await.test | 57 +++++++++++++++ test-data/unit/fixtures/async_await.py | 5 ++ test-data/unit/lib-stub/typing.py | 4 ++ 13 files changed, 264 insertions(+), 21 deletions(-) create mode 100644 test-data/unit/check-async-await.test create mode 100644 test-data/unit/fixtures/async_await.py diff --git a/mypy/checker.py b/mypy/checker.py index 5bacc9d2bb62..0eba8b003f6d 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -25,6 +25,7 @@ YieldFromExpr, NamedTupleExpr, SetComprehension, DictionaryComprehension, ComplexExpr, EllipsisExpr, TypeAliasExpr, RefExpr, YieldExpr, BackquoteExpr, ImportFrom, ImportAll, ImportBase, + AwaitExpr, CONTRAVARIANT, COVARIANT ) from mypy.nodes import function_type, method_type, method_type_with_fallback @@ -257,12 +258,45 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: self.msg.overloaded_signatures_overlap(i + 1, i + j + 2, item.func) + # Here's the scoop about generators and coroutines. + # + # There are two kinds of generators: classic generators (functions + # with `yield` or `yield from` in the body) and coroutines + # (functions declared with `async def`). The latter are specified + # in PEP 492 and only available in Python >= 3.5. + # + # Classic generators can be parameterized with three types: + # - ty is the yield type (the type of y in `yield y`) + # - ts is the type receive by yield (the type of s in `s = yield`) + # - tr is the return type (the type of r in `return r`) + # + # A classic generator must define a return type that's either + # `Generator[ty, ts, tr]`, Iterator[ty], or Iterable[ty] (or + # object or Any). If ts/tr are not given, both are Any. + # + # A coroutine must define a return type corresponding to tr; the + # other two are fixed: ty is AbstractFuture[tr], ts in Any. The + # "external" return type (seen by the caller) is Awaitable[tr]. + # + # There are several useful methods: + # + # - is_generator_return_type(t) returns whether t is a Generator, + # Iterator, Iterable, or Awaitable. + # - get_generator_yield_type(t) returns ty. + # - get_generator_receive_type(t) returns ts. + # - get_generator_return_type(t) returns tr. + def is_generator_return_type(self, typ: Type) -> bool: - return is_subtype(self.named_generic_type('typing.Generator', - [AnyType(), AnyType(), AnyType()]), - typ) + """Is `typ` a valid type for generator? + + True if either Generator or Awaitable is a supertype of `typ`. + """ + gt = self.named_generic_type('typing.Generator', [AnyType(), AnyType(), AnyType()]) + at = self.named_generic_type('typing.Awaitable', [AnyType()]) + return is_subtype(gt, typ) or is_subtype(at, typ) def get_generator_yield_type(self, return_type: Type) -> Type: + """Given the declared return type of a generator (t), return the type it yields (ty).""" if isinstance(return_type, AnyType): return AnyType() elif not self.is_generator_return_type(return_type): @@ -272,6 +306,8 @@ def get_generator_yield_type(self, return_type: Type) -> Type: elif not isinstance(return_type, Instance): # Same as above, but written as a separate branch so the typechecker can understand. return AnyType() + elif return_type.type.fullname() == 'typing.Awaitable': + return AnyType() # XXX TODO: AbstractFuture[tr] elif return_type.args: return return_type.args[0] else: @@ -281,6 +317,7 @@ def get_generator_yield_type(self, return_type: Type) -> Type: return AnyType() def get_generator_receive_type(self, return_type: Type) -> Type: + """Given a declared generator return type (t), return the type its yield receives (ts).""" if isinstance(return_type, AnyType): return AnyType() elif not self.is_generator_return_type(return_type): @@ -291,14 +328,22 @@ def get_generator_receive_type(self, return_type: Type) -> Type: # Same as above, but written as a separate branch so the typechecker can understand. return AnyType() elif return_type.type.fullname() == 'typing.Generator': - # Generator is the only type which specifies the type of values it can receive. - return return_type.args[1] + # Generator is one of the two types which specify the type of values it can receive. + if len(return_type.args) == 3: + return return_type.args[1] + else: + return AnyType() + elif return_type.type.fullname() == 'typing.Awaitable': + # Generator is one of the two types which specify the type of values it can receive. + # According to the stub this is always `Any`. + return AnyType() else: # `return_type` is a supertype of Generator, so callers won't be able to send it # values. return Void() def get_generator_return_type(self, return_type: Type) -> Type: + """Given the declared return type of a generator (t), return the type it returns (tr).""" if isinstance(return_type, AnyType): return AnyType() elif not self.is_generator_return_type(return_type): @@ -309,9 +354,19 @@ def get_generator_return_type(self, return_type: Type) -> Type: # Same as above, but written as a separate branch so the typechecker can understand. return AnyType() elif return_type.type.fullname() == 'typing.Generator': - # Generator is the only type which specifies the type of values it returns into - # `yield from` expressions. - return return_type.args[2] + # Generator is one of the two types which specify the type of values it returns into + # `yield from` expressions (using a `return` statements). + if len(return_type.args) == 3: + return return_type.args[2] + else: + return AnyType() + elif return_type.type.fullname() == 'typing.Awaitable': + # Awaitable is the other type which specifies the type of values it returns into + # `yield from` expressions (using `return`). + if len(return_type.args) == 1: + return return_type.args[0] + else: + return AnyType() else: # `return_type` is supertype of Generator, so callers won't be able to see the return # type when used in a `yield from` expression. @@ -1909,6 +1964,31 @@ def visit_yield_expr(self, e: YieldExpr) -> Type: 'actual type', 'expected type') return self.get_generator_receive_type(return_type) + def visit_await_expr(self, e: AwaitExpr) -> Type: + any_type = AnyType() + awaitable_type = self.named_generic_type('typing.Awaitable', [any_type]) + generator_type = self.named_generic_type('typing.Generator', [any_type] * 3) + # XXX Instead of a union, check for both types separately and produce custom message + expected_type = UnionType([awaitable_type, generator_type]) + actual_type = self.accept(e.expr, expected_type) + if isinstance(actual_type, AnyType): + return any_type + if is_subtype(actual_type, generator_type): + if isinstance(actual_type, Instance) and len(actual_type.args) == 3: + return actual_type.args[2] + else: + return any_type # Must've been unparameterized Generator. + elif is_subtype(actual_type, awaitable_type): + if isinstance(actual_type, Instance) and len(actual_type.args) == 1: + return actual_type.args[0] + else: + return any_type # Must've been unparameterized Awaitable. + msg = "{} (actual type {}, expected Awaitable or Generator)".format( + messages.INCOMPATIBLE_TYPES_IN_AWAIT, + self.msg.format(actual_type)) + self.fail(msg, e) + return any_type + # # Helpers # diff --git a/mypy/fastparse.py b/mypy/fastparse.py index c76e8b98f583..3e27698f2fc0 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -14,9 +14,12 @@ UnaryExpr, FuncExpr, ComparisonExpr, StarExpr, YieldFromExpr, NonlocalDecl, DictionaryComprehension, SetComprehension, ComplexExpr, EllipsisExpr, YieldExpr, Argument, + AwaitExpr, AsyncForStmt, AsyncWithStmt, ARG_POS, ARG_OPT, ARG_STAR, ARG_NAMED, ARG_STAR2 ) -from mypy.types import Type, CallableType, AnyType, UnboundType, TupleType, TypeList, EllipsisType +from mypy.types import ( + Type, CallableType, FunctionLike, AnyType, UnboundType, TupleType, TypeList, EllipsisType, +) from mypy import defaults from mypy import experiments from mypy.errors import Errors @@ -242,6 +245,11 @@ def visit_Module(self, mod: ast35.Module) -> Node: # arg? kwarg, expr* defaults) @with_line def visit_FunctionDef(self, n: ast35.FunctionDef) -> Node: + return self.do_func_def(n) + + def do_func_def(self, n: Union[ast35.FunctionDef, ast35.AsyncFunctionDef], + is_coroutine: bool = False) -> Node: + """Helper shared between visit_FunctionDef and visit_AsyncFunctionDef.""" args = self.transform_args(n.args, n.lineno) arg_kinds = [arg.kind for arg in args] @@ -285,6 +293,9 @@ def visit_FunctionDef(self, n: ast35.FunctionDef) -> Node: args, self.as_block(n.body, n.lineno), func_type) + if is_coroutine: + # A coroutine is also a generator, mostly for internal reasons. + func_def.is_generator = func_def.is_coroutine = True if func_type is not None: func_type.definition = func_def func_type.line = n.lineno @@ -345,9 +356,6 @@ def make_argument(arg: ast35.arg, default: Optional[ast35.expr], kind: int) -> A return new_args - # TODO: AsyncFunctionDef(identifier name, arguments args, - # stmt* body, expr* decorator_list, expr? returns, string? type_comment) - def stringify_name(self, n: ast35.AST) -> str: if isinstance(n, ast35.Name): return n.id @@ -419,7 +427,6 @@ def visit_For(self, n: ast35.For) -> Node: self.as_block(n.body, n.lineno), self.as_block(n.orelse, n.lineno)) - # TODO: AsyncFor(expr target, expr iter, stmt* body, stmt* orelse) # While(expr test, stmt* body, stmt* orelse) @with_line def visit_While(self, n: ast35.While) -> Node: @@ -441,8 +448,6 @@ def visit_With(self, n: ast35.With) -> Node: [self.visit(i.optional_vars) for i in n.items], self.as_block(n.body, n.lineno)) - # TODO: AsyncWith(withitem* items, stmt* body) - # Raise(expr? exc, expr? cause) @with_line def visit_Raise(self, n: ast35.Raise) -> Node: @@ -628,8 +633,6 @@ def visit_GeneratorExp(self, n: ast35.GeneratorExp) -> GeneratorExpr: iters, ifs_list) - # TODO: Await(expr value) - # Yield(expr? value) @with_line def visit_Yield(self, n: ast35.Yield) -> Node: @@ -762,6 +765,28 @@ def visit_ExtSlice(self, n: ast35.ExtSlice) -> Node: def visit_Index(self, n: ast35.Index) -> Node: return self.visit(n.value) + # PEP 492 nodes: 'async def', 'await', 'async for', 'async with'. + + # AsyncFunctionDef(identifier name, arguments args, + # stmt* body, expr* decorator_list, expr? returns, string? type_comment) + def visit_AsyncFunctionDef(self, n: ast35.AsyncFunctionDef) -> Node: + return self.do_func_def(n, is_coroutine=True) + + # Await(expr value) + def visit_Await(self, n: ast35.Await) -> Node: + v = self.visit(n.value) + return AwaitExpr(v) + + # AsyncFor(expr target, expr iter, stmt* body, stmt* orelse) + def visit_AsyncFor(self, n: ast35.AsyncFor) -> Node: + self.visit_list(n.body) # XXX + return AsyncForStmt() # XXX + + # AsyncWith(withitem* items, stmt* body) + def visit_AsyncWith(self, n: ast35.AsyncWith) -> Node: + self.visit_list(n.body) # XXX + return AsyncWithStmt() # XXX + class TypeConverter(ast35.NodeTransformer): def __init__(self, line: int = -1) -> None: diff --git a/mypy/messages.py b/mypy/messages.py index d37dcdb2be77..a78ccba8cbfd 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -40,6 +40,7 @@ INCOMPATIBLE_TYPES = 'Incompatible types' INCOMPATIBLE_TYPES_IN_ASSIGNMENT = 'Incompatible types in assignment' INCOMPATIBLE_REDEFINITION = 'Incompatible redefinition' +INCOMPATIBLE_TYPES_IN_AWAIT = 'Incompatible types in await' INCOMPATIBLE_TYPES_IN_YIELD = 'Incompatible types in yield' INCOMPATIBLE_TYPES_IN_YIELD_FROM = 'Incompatible types in "yield from"' INCOMPATIBLE_TYPES_IN_STR_INTERPOLATION = 'Incompatible types in string interpolation' diff --git a/mypy/nodes.py b/mypy/nodes.py index fe4da0d227d4..bc4a186a3f77 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -416,6 +416,7 @@ class FuncItem(FuncBase): # Is this an overload variant of function with more than one overload variant? is_overload = False is_generator = False # Contains a yield statement? + is_coroutine = False # Defined using 'async def'? is_static = False # Uses @staticmethod? is_class = False # Uses @classmethod? # Variants of function with type variables with values expanded @@ -1705,6 +1706,46 @@ def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit__promote_expr(self) +# PEP 492 nodes: 'await', 'async for', 'async with'. +# ('async def' is FunDef(..., is_coroutine=True).) + + +class AwaitExpr(Node): + """Await expression (await ...).""" + # TODO: [de]serialize() + # TODO: arg type must be AsyncIterator[E] or Generator[E, ..., ...] (or Iterator[E]?) + # and then the return type is E + + expr = None # type: Node + type = None # type: Type + + def __init__(self, expr: Node) -> None: + self.expr = expr + + def accept(self, visitor: NodeVisitor[T]) -> T: + return visitor.visit_await_expr(self) + + +class AsyncForStmt(Node): + """Asynchronous for statement (async for ...).""" + # TODO: constructor + # TODO: [de]serialize + # TODO: types + + def accept(self, visitor: NodeVisitor[T]) -> T: + print("XXX AsyncForStmt.accept") + + +class AsyncWithStmt(Node): + """Asynchronous with statement (async with ...).""" + # TODO: constructor + # TODO: [de]serialize + # TODO: types + + def accept(self, visitor: NodeVisitor[T]) -> T: + print("XXX AsyncWithStmt.accept") + + # Constants diff --git a/mypy/parse.py b/mypy/parse.py index 3007902fde6f..d76ca55c75a3 100644 --- a/mypy/parse.py +++ b/mypy/parse.py @@ -957,6 +957,10 @@ def parse_statement(self) -> Tuple[Node, bool]: stmt = self.parse_exec_stmt() else: stmt = self.parse_expression_or_assignment() + if ts == 'async' and self.current_str() == 'def': + self.parse_error_at(self.current(), + reason='Use --fast-parser to parse code using "async def"') + raise ParseError() if stmt is not None: stmt.set_line(t) return stmt, is_simple diff --git a/mypy/semanal.py b/mypy/semanal.py index 99cba07d0e4f..4738421841ca 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -62,16 +62,16 @@ ComparisonExpr, StarExpr, ARG_POS, ARG_NAMED, MroError, type_aliases, YieldFromExpr, NamedTupleExpr, NonlocalDecl, SetComprehension, DictionaryComprehension, TYPE_ALIAS, TypeAliasExpr, - YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, COVARIANT, CONTRAVARIANT, + YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, AwaitExpr, IntExpr, FloatExpr, UnicodeExpr, - INVARIANT, UNBOUND_IMPORTED + COVARIANT, CONTRAVARIANT, INVARIANT, UNBOUND_IMPORTED ) from mypy.visitor import NodeVisitor from mypy.traverser import TraverserVisitor from mypy.errors import Errors, report_internal_error from mypy.types import ( NoneTyp, CallableType, Overloaded, Instance, Type, TypeVarType, AnyType, - FunctionLike, UnboundType, TypeList, ErrorType, TypeVarDef, + FunctionLike, UnboundType, TypeList, ErrorType, TypeVarDef, Void, replace_leading_arg_type, TupleType, UnionType, StarType, EllipsisType ) from mypy.nodes import function_type, implicit_module_attrs @@ -314,6 +314,13 @@ def visit_func_def(self, defn: FuncDef) -> None: # Second phase of analysis for function. self.errors.push_function(defn.name()) self.analyze_function(defn) + if defn.is_coroutine: + # A coroutine defined as `async def foo(...) -> T: ...` + # has external return type `Awaitable[T]`. + defn.type = defn.type.copy_modified( + ret_type=Instance( + self.named_type_or_none('typing.Awaitable').type, + [defn.type.ret_type])) self.errors.pop_function() def prepare_method_signature(self, func: FuncDef) -> None: @@ -2072,6 +2079,14 @@ def visit_yield_expr(self, expr: YieldExpr) -> None: if expr.expr: expr.expr.accept(self) + def visit_await_expr(self, expr: AwaitExpr) -> None: + if not self.is_func_scope(): + self.fail("'await' outside function", expr) + elif not self.function_stack[-1].is_coroutine: + self.fail("'await' outside coroutine ('async def')", expr) + if expr.expr: + expr.expr.accept(self) + # # Helpers # diff --git a/mypy/strconv.py b/mypy/strconv.py index 8d2c0845d70d..e898a7e6c40f 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -243,6 +243,9 @@ def visit_yield_from_stmt(self, o): def visit_yield_expr(self, o): return self.dump([o.expr], o) + def visit_await_expr(self, o): + return self.dump([o.expr], o) + def visit_del_stmt(self, o): return self.dump([o.expr], o) diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index 0136e5a9f147..ec054264c322 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -63,6 +63,7 @@ 'check-optional.test', 'check-fastparse.test', 'check-warnings.test', + 'check-async-await.test', ] diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 829b86dc4793..07c901169a7a 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -19,7 +19,8 @@ ComparisonExpr, TempNode, StarExpr, YieldFromExpr, NamedTupleExpr, NonlocalDecl, SetComprehension, DictionaryComprehension, ComplexExpr, TypeAliasExpr, EllipsisExpr, - YieldExpr, ExecStmt, Argument, BackquoteExpr + YieldExpr, ExecStmt, Argument, BackquoteExpr, + AwaitExpr ) from mypy.types import Type, FunctionLike, Instance from mypy.visitor import NodeVisitor @@ -339,6 +340,9 @@ def visit_yield_from_expr(self, node: YieldFromExpr) -> Node: def visit_yield_expr(self, node: YieldExpr) -> Node: return YieldExpr(self.node(node.expr)) + def visit_await_expr(self, node: AwaitExpr) -> Type: + return AwaitExpr(self.node(node.expr)) + def visit_call_expr(self, node: CallExpr) -> Node: return CallExpr(self.node(node.callee), self.nodes(node.args), diff --git a/mypy/visitor.py b/mypy/visitor.py index b1e1b883a109..43e7c161ea6d 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -228,5 +228,8 @@ def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> T: def visit__promote_expr(self, o: 'mypy.nodes.PromoteExpr') -> T: pass + def visit_await_expr(self, o: 'mypy.nodes.AwaitExpr') -> T: + pass + def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> T: pass diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test new file mode 100644 index 000000000000..f2ac637a2651 --- /dev/null +++ b/test-data/unit/check-async-await.test @@ -0,0 +1,57 @@ +-- Tests for async def and await (PEP 492) +-- --------------------------------------- + +[case testAsyncDefPass] +# options: fast_parser +async def f() -> int: + pass +[builtins fixtures/async_await.py] + +[case testAsyncDefReturn] +# options: fast_parser +async def f() -> int: + return 0 +reveal_type(f()) # E: Revealed type is 'typing.Awaitable[builtins.int]' +[builtins fixtures/async_await.py] + +[case testAwaitCoroutine] +# options: fast_parser +async def f() -> int: + x = await f() + return x +[builtins fixtures/async_await.py] + +[case testAwaitGenerator] +# options: fast_parser +from typing import Any, Generator +def g() -> Generator[int, int, int]: + x = yield 0 + return x +async def f() -> int: + x = await g() + return x +[builtins fixtures/async_await.py] + +[case testAwaitArgumentError] +# options: fast_parser +def g() -> int: + return 0 +async def f() -> int: + x = await g() + return x +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": +main: error: Incompatible types in await (actual type "int", expected Awaitable or Generator) + +[case testAwaitResultError] +# options: fast_parser +async def g() -> int: + return 0 +async def f() -> str: + x = await g() + return x +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": +main:6: error: Incompatible return value type (got "int", expected "str") diff --git a/test-data/unit/fixtures/async_await.py b/test-data/unit/fixtures/async_await.py new file mode 100644 index 000000000000..742cdb42fe38 --- /dev/null +++ b/test-data/unit/fixtures/async_await.py @@ -0,0 +1,5 @@ +import typing +class function: pass +class int: pass +class object: pass +class str: pass diff --git a/test-data/unit/lib-stub/typing.py b/test-data/unit/lib-stub/typing.py index 09c76a7eb1bf..fe68df81279f 100644 --- a/test-data/unit/lib-stub/typing.py +++ b/test-data/unit/lib-stub/typing.py @@ -57,6 +57,10 @@ def close(self) -> None: pass @abstractmethod def __iter__(self) -> 'Generator[T, U, V]': pass +class Awaitable(Generic[T]): + @abstractmethod + def __await__(self) -> Generator[Any, Any, T]: pass + class Sequence(Generic[T]): @abstractmethod def __getitem__(self, n: Any) -> T: pass From be91786a6c1f1ca237adc3f8ff8adda336c245ef Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 1 Jul 2016 17:38:33 -0700 Subject: [PATCH 02/28] Fix type errors; add async for and async with (not fully fledged). --- mypy/checker.py | 22 ++++++++++++++- mypy/fastparse.py | 11 +++++--- mypy/nodes.py | 37 +++++++++++++++++++++---- mypy/semanal.py | 40 +++++++++++++++++++++++++-- mypy/traverser.py | 16 ++++++++++- mypy/treetransform.py | 2 +- mypy/visitor.py | 6 ++++ test-data/unit/check-async-await.test | 10 +++++++ 8 files changed, 129 insertions(+), 15 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 0eba8b003f6d..9582e5a289f4 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -25,7 +25,7 @@ YieldFromExpr, NamedTupleExpr, SetComprehension, DictionaryComprehension, ComplexExpr, EllipsisExpr, TypeAliasExpr, RefExpr, YieldExpr, BackquoteExpr, ImportFrom, ImportAll, ImportBase, - AwaitExpr, + AwaitExpr, AsyncForStmt, AsyncWithStmt, CONTRAVARIANT, COVARIANT ) from mypy.nodes import function_type, method_type, method_type_with_fallback @@ -1663,6 +1663,13 @@ def visit_for_stmt(self, s: ForStmt) -> Type: self.analyze_index_variables(s.index, item_type, s) self.accept_loop(s.body, s.else_body) + def visit_async_for_stmt(self, s: AsyncForStmt) -> Type: + """Type check an `async for` statement.""" + # TODO: Type??? + item_type = self.analyze_iterable_item_type(s.expr) + self.analyze_index_variables(s.index, item_type, s) + self.accept_loop(s.body, s.else_body) + def analyze_iterable_item_type(self, expr: Node) -> Type: """Analyse iterable expression and return iterator item type.""" iterable = self.accept(expr) @@ -1781,6 +1788,19 @@ def visit_with_stmt(self, s: WithStmt) -> Type: echk.check_call(exit, [arg] * 3, [nodes.ARG_POS] * 3, expr) self.accept(s.body) + def visit_async_with_stmt(self, s: AsyncWithStmt) -> Type: + echk = self.expr_checker + for expr, target in zip(s.expr, s.target): + ctx = self.accept(expr) + enter = echk.analyze_external_member_access('__enter__', ctx, expr) + obj = echk.check_call(enter, [], [], expr)[0] + if target: + self.check_assignment(target, self.temp_node(obj, expr)) + exit = echk.analyze_external_member_access('__exit__', ctx, expr) + arg = self.temp_node(AnyType(), expr) + echk.check_call(exit, [arg] * 3, [nodes.ARG_POS] * 3, expr) + self.accept(s.body) + def visit_print_stmt(self, s: PrintStmt) -> Type: for arg in s.args: self.accept(arg) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 3e27698f2fc0..cd37e102ad0e 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -779,13 +779,16 @@ def visit_Await(self, n: ast35.Await) -> Node: # AsyncFor(expr target, expr iter, stmt* body, stmt* orelse) def visit_AsyncFor(self, n: ast35.AsyncFor) -> Node: - self.visit_list(n.body) # XXX - return AsyncForStmt() # XXX + return AsyncForStmt(self.visit(n.target), + self.visit(n.iter), + self.as_block(n.body, n.lineno), + self.as_block(n.orelse, n.lineno)) # AsyncWith(withitem* items, stmt* body) def visit_AsyncWith(self, n: ast35.AsyncWith) -> Node: - self.visit_list(n.body) # XXX - return AsyncWithStmt() # XXX + return AsyncWithStmt([self.visit(i.context_expr) for i in n.items], + [self.visit(i.optional_vars) for i in n.items], + self.as_block(n.body, n.lineno)) class TypeConverter(ast35.NodeTransformer): diff --git a/mypy/nodes.py b/mypy/nodes.py index bc4a186a3f77..15845909b0a6 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -487,6 +487,7 @@ def serialize(self) -> JsonDict: 'is_property': self.is_property, 'is_overload': self.is_overload, 'is_generator': self.is_generator, + 'is_coroutine': self.is_coroutine, 'is_static': self.is_static, 'is_class': self.is_class, 'is_decorated': self.is_decorated, @@ -508,6 +509,7 @@ def deserialize(cls, data: JsonDict) -> 'FuncDef': ret.is_property = data['is_property'] ret.is_overload = data['is_overload'] ret.is_generator = data['is_generator'] + ret.is_coroutine = data['is_coroutine'] ret.is_static = data['is_static'] ret.is_class = data['is_class'] ret.is_decorated = data['is_decorated'] @@ -1707,7 +1709,7 @@ def accept(self, visitor: NodeVisitor[T]) -> T: # PEP 492 nodes: 'await', 'async for', 'async with'. -# ('async def' is FunDef(..., is_coroutine=True).) +# ('async def' is a FuncDef with is_coroutine = True.) class AwaitExpr(Node): @@ -1717,7 +1719,6 @@ class AwaitExpr(Node): # and then the return type is E expr = None # type: Node - type = None # type: Type def __init__(self, expr: Node) -> None: self.expr = expr @@ -1728,22 +1729,46 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class AsyncForStmt(Node): """Asynchronous for statement (async for ...).""" - # TODO: constructor + # TODO Maybe just use ForStmt with an extra flag? # TODO: [de]serialize # TODO: types + # Index variables + index = None # type: Expression + # Expression to iterate + expr = None # type: Expression + body = None # type: Block + else_body = None # type: Block + + def __init__(self, index: Expression, expr: Expression, body: Block, + else_body: Block) -> None: + self.index = index + self.expr = expr + self.body = body + self.else_body = else_body + def accept(self, visitor: NodeVisitor[T]) -> T: - print("XXX AsyncForStmt.accept") + return visitor.visit_async_for_stmt(self) class AsyncWithStmt(Node): """Asynchronous with statement (async with ...).""" - # TODO: constructor + # TODO Maybe just use WithStmt with an extra flag? # TODO: [de]serialize # TODO: types + expr = None # type: List[Expression] + target = None # type: List[Expression] + body = None # type: Block + + def __init__(self, expr: List[Expression], target: List[Expression], + body: Block) -> None: + self.expr = expr + self.target = target + self.body = body + def accept(self, visitor: NodeVisitor[T]) -> T: - print("XXX AsyncWithStmt.accept") + return visitor.visit_async_with_stmt(self) # Constants diff --git a/mypy/semanal.py b/mypy/semanal.py index 4738421841ca..cff4ab2f4a29 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -62,8 +62,9 @@ ComparisonExpr, StarExpr, ARG_POS, ARG_NAMED, MroError, type_aliases, YieldFromExpr, NamedTupleExpr, NonlocalDecl, SetComprehension, DictionaryComprehension, TYPE_ALIAS, TypeAliasExpr, - YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, AwaitExpr, + YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, IntExpr, FloatExpr, UnicodeExpr, + AwaitExpr, AsyncForStmt, AsyncWithStmt, COVARIANT, CONTRAVARIANT, INVARIANT, UNBOUND_IMPORTED ) from mypy.visitor import NodeVisitor @@ -314,7 +315,7 @@ def visit_func_def(self, defn: FuncDef) -> None: # Second phase of analysis for function. self.errors.push_function(defn.name()) self.analyze_function(defn) - if defn.is_coroutine: + if defn.is_coroutine and isinstance(defn.type, CallableType): # A coroutine defined as `async def foo(...) -> T: ...` # has external return type `Awaitable[T]`. defn.type = defn.type.copy_modified( @@ -1686,6 +1687,19 @@ def visit_for_stmt(self, s: ForStmt) -> None: self.visit_block_maybe(s.else_body) + def visit_async_for_stmt(self, s: AsyncForStmt) -> None: + # TODO: type of expr. + s.expr.accept(self) + + # Bind index variables and check if they define new names. + self.analyze_lvalue(s.index) + + self.loop_depth += 1 + self.visit_block(s.body) + self.loop_depth -= 1 + + self.visit_block_maybe(s.else_body) + def visit_break_stmt(self, s: BreakStmt) -> None: if self.loop_depth == 0: self.fail("'break' outside loop", s, True, blocker=True) @@ -1725,6 +1739,14 @@ def visit_with_stmt(self, s: WithStmt) -> None: self.analyze_lvalue(n) self.visit_block(s.body) + def visit_async_with_stmt(self, s: AsyncWithStmt) -> None: + # TODO: Type??? + for e, n in zip(s.expr, s.target): + e.accept(self) + if n: + self.analyze_lvalue(n) + self.visit_block(s.body) + def visit_del_stmt(self, s: DelStmt) -> None: s.expr.accept(self) if not self.is_valid_del_target(s.expr): @@ -2490,12 +2512,26 @@ def visit_for_stmt(self, s: ForStmt) -> None: if s.else_body: s.else_body.accept(self) + def visit_async_for_stmt(self, s: AsyncForStmt) -> None: + # TODO: Type of s.expr + self.analyze_lvalue(s.index) + s.body.accept(self) + if s.else_body: + s.else_body.accept(self) + def visit_with_stmt(self, s: WithStmt) -> None: for n in s.target: if n: self.analyze_lvalue(n) s.body.accept(self) + def visit_async_with_stmt(self, s: AsyncWithStmt) -> None: + # TODO: Type ??? + for n in s.target: + if n: + self.analyze_lvalue(n) + s.body.accept(self) + def visit_decorator(self, d: Decorator) -> None: d.var._fullname = self.sem.qualified_name(d.var.name()) self.sem.add_symbol(d.var.name(), SymbolTableNode(GDEF, d.var), d) diff --git a/mypy/traverser.py b/mypy/traverser.py index ddd4ea7aaa42..f458fde9159a 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -9,7 +9,7 @@ UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr, IndexExpr, GeneratorExpr, ListComprehension, ConditionalExpr, TypeApplication, FuncExpr, ComparisonExpr, OverloadedFuncDef, YieldFromExpr, - YieldExpr + YieldExpr, AsyncForStmt, AsyncWithStmt, ) @@ -84,6 +84,13 @@ def visit_for_stmt(self, o: ForStmt) -> None: if o.else_body: o.else_body.accept(self) + def visit_async_for_stmt(self, o: AsyncForStmt) -> None: + o.index.accept(self) + o.expr.accept(self) + o.body.accept(self) + if o.else_body: + o.else_body.accept(self) + def visit_return_stmt(self, o: ReturnStmt) -> None: if o.expr is not None: o.expr.accept(self) @@ -128,6 +135,13 @@ def visit_with_stmt(self, o: WithStmt) -> None: o.target[i].accept(self) o.body.accept(self) + def visit_async_with_stmt(self, o: AsyncWithStmt) -> None: + for i in range(len(o.expr)): + o.expr[i].accept(self) + if o.target[i] is not None: + o.target[i].accept(self) + o.body.accept(self) + def visit_member_expr(self, o: MemberExpr) -> None: o.expr.accept(self) diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 07c901169a7a..6ef62ec3d637 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -340,7 +340,7 @@ def visit_yield_from_expr(self, node: YieldFromExpr) -> Node: def visit_yield_expr(self, node: YieldExpr) -> Node: return YieldExpr(self.node(node.expr)) - def visit_await_expr(self, node: AwaitExpr) -> Type: + def visit_await_expr(self, node: AwaitExpr) -> Node: return AwaitExpr(self.node(node.expr)) def visit_call_expr(self, node: CallExpr) -> Node: diff --git a/mypy/visitor.py b/mypy/visitor.py index 43e7c161ea6d..a67e1fc02ae6 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -79,6 +79,9 @@ def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T: def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T: pass + def visit_async_for_stmt(self, o: 'mypy.nodes.AsyncForStmt') -> T: + pass + def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T: pass @@ -109,6 +112,9 @@ def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T: def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T: pass + def visit_async_with_stmt(self, o: 'mypy.nodes.AsyncWithStmt') -> T: + pass + def visit_print_stmt(self, o: 'mypy.nodes.PrintStmt') -> T: pass diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test index f2ac637a2651..4cb97e934147 100644 --- a/test-data/unit/check-async-await.test +++ b/test-data/unit/check-async-await.test @@ -55,3 +55,13 @@ async def f() -> str: [out] main: note: In function "f": main:6: error: Incompatible return value type (got "int", expected "str") + +[case testAsyncFor] +# options: fast_parser +class C: + def __next__(self): return self + async def __aiter__(self) -> int: return 0 +async def f() -> None: + async for x in C(): + pass # TODO: reveal or check type +[builtins fixtures/async_await.py] From cf3b9e039a0847ae8a5facc9efdad278213e36b2 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 2 Jul 2016 09:44:48 -0700 Subject: [PATCH 03/28] Dispose of Async{For,With}Stmt -- use is_async flag instead. --- mypy/checker.py | 23 +------------------- mypy/fastparse.py | 20 ++++++++++------- mypy/nodes.py | 50 ++----------------------------------------- mypy/semanal.py | 40 ++-------------------------------- mypy/traverser.py | 16 +------------- mypy/treetransform.py | 3 +-- mypy/visitor.py | 6 ------ 7 files changed, 19 insertions(+), 139 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 9582e5a289f4..d3c75939fe4c 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -25,7 +25,7 @@ YieldFromExpr, NamedTupleExpr, SetComprehension, DictionaryComprehension, ComplexExpr, EllipsisExpr, TypeAliasExpr, RefExpr, YieldExpr, BackquoteExpr, ImportFrom, ImportAll, ImportBase, - AwaitExpr, AsyncForStmt, AsyncWithStmt, + AwaitExpr, CONTRAVARIANT, COVARIANT ) from mypy.nodes import function_type, method_type, method_type_with_fallback @@ -1663,13 +1663,6 @@ def visit_for_stmt(self, s: ForStmt) -> Type: self.analyze_index_variables(s.index, item_type, s) self.accept_loop(s.body, s.else_body) - def visit_async_for_stmt(self, s: AsyncForStmt) -> Type: - """Type check an `async for` statement.""" - # TODO: Type??? - item_type = self.analyze_iterable_item_type(s.expr) - self.analyze_index_variables(s.index, item_type, s) - self.accept_loop(s.body, s.else_body) - def analyze_iterable_item_type(self, expr: Node) -> Type: """Analyse iterable expression and return iterator item type.""" iterable = self.accept(expr) @@ -1788,19 +1781,6 @@ def visit_with_stmt(self, s: WithStmt) -> Type: echk.check_call(exit, [arg] * 3, [nodes.ARG_POS] * 3, expr) self.accept(s.body) - def visit_async_with_stmt(self, s: AsyncWithStmt) -> Type: - echk = self.expr_checker - for expr, target in zip(s.expr, s.target): - ctx = self.accept(expr) - enter = echk.analyze_external_member_access('__enter__', ctx, expr) - obj = echk.check_call(enter, [], [], expr)[0] - if target: - self.check_assignment(target, self.temp_node(obj, expr)) - exit = echk.analyze_external_member_access('__exit__', ctx, expr) - arg = self.temp_node(AnyType(), expr) - echk.check_call(exit, [arg] * 3, [nodes.ARG_POS] * 3, expr) - self.accept(s.body) - def visit_print_stmt(self, s: PrintStmt) -> Type: for arg in s.args: self.accept(arg) @@ -1988,7 +1968,6 @@ def visit_await_expr(self, e: AwaitExpr) -> Type: any_type = AnyType() awaitable_type = self.named_generic_type('typing.Awaitable', [any_type]) generator_type = self.named_generic_type('typing.Generator', [any_type] * 3) - # XXX Instead of a union, check for both types separately and produce custom message expected_type = UnionType([awaitable_type, generator_type]) actual_type = self.accept(e.expr, expected_type) if isinstance(actual_type, AnyType): diff --git a/mypy/fastparse.py b/mypy/fastparse.py index cd37e102ad0e..0770adf712a4 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -14,7 +14,7 @@ UnaryExpr, FuncExpr, ComparisonExpr, StarExpr, YieldFromExpr, NonlocalDecl, DictionaryComprehension, SetComprehension, ComplexExpr, EllipsisExpr, YieldExpr, Argument, - AwaitExpr, AsyncForStmt, AsyncWithStmt, + AwaitExpr, ARG_POS, ARG_OPT, ARG_STAR, ARG_NAMED, ARG_STAR2 ) from mypy.types import ( @@ -779,16 +779,20 @@ def visit_Await(self, n: ast35.Await) -> Node: # AsyncFor(expr target, expr iter, stmt* body, stmt* orelse) def visit_AsyncFor(self, n: ast35.AsyncFor) -> Node: - return AsyncForStmt(self.visit(n.target), - self.visit(n.iter), - self.as_block(n.body, n.lineno), - self.as_block(n.orelse, n.lineno)) + r = ForStmt(self.visit(n.target), + self.visit(n.iter), + self.as_block(n.body, n.lineno), + self.as_block(n.orelse, n.lineno)) + r.is_async = True + return r # AsyncWith(withitem* items, stmt* body) def visit_AsyncWith(self, n: ast35.AsyncWith) -> Node: - return AsyncWithStmt([self.visit(i.context_expr) for i in n.items], - [self.visit(i.optional_vars) for i in n.items], - self.as_block(n.body, n.lineno)) + r = WithStmt([self.visit(i.context_expr) for i in n.items], + [self.visit(i.optional_vars) for i in n.items], + self.as_block(n.body, n.lineno)) + r.is_async = True + return r class TypeConverter(ast35.NodeTransformer): diff --git a/mypy/nodes.py b/mypy/nodes.py index 15845909b0a6..cbfaba725b63 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -801,6 +801,7 @@ class ForStmt(Statement): expr = None # type: Expression body = None # type: Block else_body = None # type: Block + is_async = False # True if `async for ...` (PEP 492, Python 3.5) def __init__(self, index: Expression, expr: Expression, body: Block, else_body: Block) -> None: @@ -911,6 +912,7 @@ class WithStmt(Statement): expr = None # type: List[Expression] target = None # type: List[Expression] body = None # type: Block + is_async = False # True if `async with ...` (PEP 492, Python 3.5) def __init__(self, expr: List[Expression], target: List[Expression], body: Block) -> None: @@ -1708,10 +1710,6 @@ def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit__promote_expr(self) -# PEP 492 nodes: 'await', 'async for', 'async with'. -# ('async def' is a FuncDef with is_coroutine = True.) - - class AwaitExpr(Node): """Await expression (await ...).""" # TODO: [de]serialize() @@ -1727,50 +1725,6 @@ def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_await_expr(self) -class AsyncForStmt(Node): - """Asynchronous for statement (async for ...).""" - # TODO Maybe just use ForStmt with an extra flag? - # TODO: [de]serialize - # TODO: types - - # Index variables - index = None # type: Expression - # Expression to iterate - expr = None # type: Expression - body = None # type: Block - else_body = None # type: Block - - def __init__(self, index: Expression, expr: Expression, body: Block, - else_body: Block) -> None: - self.index = index - self.expr = expr - self.body = body - self.else_body = else_body - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_async_for_stmt(self) - - -class AsyncWithStmt(Node): - """Asynchronous with statement (async with ...).""" - # TODO Maybe just use WithStmt with an extra flag? - # TODO: [de]serialize - # TODO: types - - expr = None # type: List[Expression] - target = None # type: List[Expression] - body = None # type: Block - - def __init__(self, expr: List[Expression], target: List[Expression], - body: Block) -> None: - self.expr = expr - self.target = target - self.body = body - - def accept(self, visitor: NodeVisitor[T]) -> T: - return visitor.visit_async_with_stmt(self) - - # Constants diff --git a/mypy/semanal.py b/mypy/semanal.py index cff4ab2f4a29..dc84756fc4bc 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -62,10 +62,9 @@ ComparisonExpr, StarExpr, ARG_POS, ARG_NAMED, MroError, type_aliases, YieldFromExpr, NamedTupleExpr, NonlocalDecl, SetComprehension, DictionaryComprehension, TYPE_ALIAS, TypeAliasExpr, - YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, + YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, AwaitExpr, IntExpr, FloatExpr, UnicodeExpr, - AwaitExpr, AsyncForStmt, AsyncWithStmt, - COVARIANT, CONTRAVARIANT, INVARIANT, UNBOUND_IMPORTED + COVARIANT, CONTRAVARIANT, INVARIANT, UNBOUND_IMPORTED, ) from mypy.visitor import NodeVisitor from mypy.traverser import TraverserVisitor @@ -1687,19 +1686,6 @@ def visit_for_stmt(self, s: ForStmt) -> None: self.visit_block_maybe(s.else_body) - def visit_async_for_stmt(self, s: AsyncForStmt) -> None: - # TODO: type of expr. - s.expr.accept(self) - - # Bind index variables and check if they define new names. - self.analyze_lvalue(s.index) - - self.loop_depth += 1 - self.visit_block(s.body) - self.loop_depth -= 1 - - self.visit_block_maybe(s.else_body) - def visit_break_stmt(self, s: BreakStmt) -> None: if self.loop_depth == 0: self.fail("'break' outside loop", s, True, blocker=True) @@ -1739,14 +1725,6 @@ def visit_with_stmt(self, s: WithStmt) -> None: self.analyze_lvalue(n) self.visit_block(s.body) - def visit_async_with_stmt(self, s: AsyncWithStmt) -> None: - # TODO: Type??? - for e, n in zip(s.expr, s.target): - e.accept(self) - if n: - self.analyze_lvalue(n) - self.visit_block(s.body) - def visit_del_stmt(self, s: DelStmt) -> None: s.expr.accept(self) if not self.is_valid_del_target(s.expr): @@ -2512,26 +2490,12 @@ def visit_for_stmt(self, s: ForStmt) -> None: if s.else_body: s.else_body.accept(self) - def visit_async_for_stmt(self, s: AsyncForStmt) -> None: - # TODO: Type of s.expr - self.analyze_lvalue(s.index) - s.body.accept(self) - if s.else_body: - s.else_body.accept(self) - def visit_with_stmt(self, s: WithStmt) -> None: for n in s.target: if n: self.analyze_lvalue(n) s.body.accept(self) - def visit_async_with_stmt(self, s: AsyncWithStmt) -> None: - # TODO: Type ??? - for n in s.target: - if n: - self.analyze_lvalue(n) - s.body.accept(self) - def visit_decorator(self, d: Decorator) -> None: d.var._fullname = self.sem.qualified_name(d.var.name()) self.sem.add_symbol(d.var.name(), SymbolTableNode(GDEF, d.var), d) diff --git a/mypy/traverser.py b/mypy/traverser.py index f458fde9159a..ddd4ea7aaa42 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -9,7 +9,7 @@ UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr, IndexExpr, GeneratorExpr, ListComprehension, ConditionalExpr, TypeApplication, FuncExpr, ComparisonExpr, OverloadedFuncDef, YieldFromExpr, - YieldExpr, AsyncForStmt, AsyncWithStmt, + YieldExpr ) @@ -84,13 +84,6 @@ def visit_for_stmt(self, o: ForStmt) -> None: if o.else_body: o.else_body.accept(self) - def visit_async_for_stmt(self, o: AsyncForStmt) -> None: - o.index.accept(self) - o.expr.accept(self) - o.body.accept(self) - if o.else_body: - o.else_body.accept(self) - def visit_return_stmt(self, o: ReturnStmt) -> None: if o.expr is not None: o.expr.accept(self) @@ -135,13 +128,6 @@ def visit_with_stmt(self, o: WithStmt) -> None: o.target[i].accept(self) o.body.accept(self) - def visit_async_with_stmt(self, o: AsyncWithStmt) -> None: - for i in range(len(o.expr)): - o.expr[i].accept(self) - if o.target[i] is not None: - o.target[i].accept(self) - o.body.accept(self) - def visit_member_expr(self, o: MemberExpr) -> None: o.expr.accept(self) diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 6ef62ec3d637..f05232586b14 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -19,8 +19,7 @@ ComparisonExpr, TempNode, StarExpr, YieldFromExpr, NamedTupleExpr, NonlocalDecl, SetComprehension, DictionaryComprehension, ComplexExpr, TypeAliasExpr, EllipsisExpr, - YieldExpr, ExecStmt, Argument, BackquoteExpr, - AwaitExpr + YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr, ) from mypy.types import Type, FunctionLike, Instance from mypy.visitor import NodeVisitor diff --git a/mypy/visitor.py b/mypy/visitor.py index a67e1fc02ae6..43e7c161ea6d 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -79,9 +79,6 @@ def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T: def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T: pass - def visit_async_for_stmt(self, o: 'mypy.nodes.AsyncForStmt') -> T: - pass - def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T: pass @@ -112,9 +109,6 @@ def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T: def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T: pass - def visit_async_with_stmt(self, o: 'mypy.nodes.AsyncWithStmt') -> T: - pass - def visit_print_stmt(self, o: 'mypy.nodes.PrintStmt') -> T: pass From c8ea54b8f72756f9f2b9c46d43c58e9a33812ed1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 2 Jul 2016 10:01:17 -0700 Subject: [PATCH 04/28] Basic `async for` is working. --- mypy/checker.py | 31 ++++++++++++++++++++++++-- mypy/messages.py | 1 + mypy/strconv.py | 8 +++++-- test-data/unit/check-async-await.test | 10 +++++---- test-data/unit/fixtures/async_await.py | 4 +++- test-data/unit/lib-stub/typing.py | 9 ++++++++ 6 files changed, 54 insertions(+), 9 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index d3c75939fe4c..9f4b83af7078 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -287,7 +287,7 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: # - get_generator_return_type(t) returns tr. def is_generator_return_type(self, typ: Type) -> bool: - """Is `typ` a valid type for generator? + """Is `typ` a valid type for generator? True if either Generator or Awaitable is a supertype of `typ`. """ @@ -1659,10 +1659,37 @@ def visit_except_handler_test(self, n: Node) -> Type: def visit_for_stmt(self, s: ForStmt) -> Type: """Type check a for statement.""" - item_type = self.analyze_iterable_item_type(s.expr) + if s.is_async: + item_type = self.analyze_async_iterable_item_type(s.expr) + else: + item_type = self.analyze_iterable_item_type(s.expr) self.analyze_index_variables(s.index, item_type, s) self.accept_loop(s.body, s.else_body) + def analyze_async_iterable_item_type(self, expr: Node) -> Type: + """Analyse async iterable expression and return iterator item type.""" + iterable = self.accept(expr) + + self.check_not_void(iterable, expr) + + self.check_subtype(iterable, + self.named_generic_type('typing.AsyncIterable', + [AnyType()]), + expr, messages.ASYNC_ITERABLE_EXPECTED) + + echk = self.expr_checker + method = echk.analyze_external_member_access('__aiter__', iterable, expr) + iterator = echk.check_call(method, [], [], expr)[0] + method = echk.analyze_external_member_access('__anext__', iterator, expr) + awaitable = echk.check_call(method, [], [], expr)[0] + method = echk.analyze_external_member_access('__await__', awaitable, expr) + generator = echk.check_call(method, [], [], expr)[0] + if (isinstance(generator, Instance) and len(generator.args) == 3 + and generator.type.fullname() == 'typing.Generator'): + return generator.args[2] + else: + return AnyType() + def analyze_iterable_item_type(self, expr: Node) -> Type: """Analyse iterable expression and return iterator item type.""" iterable = self.accept(expr) diff --git a/mypy/messages.py b/mypy/messages.py index a78ccba8cbfd..b922449f94f9 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -58,6 +58,7 @@ INCOMPATIBLE_VALUE_TYPE = 'Incompatible dictionary value type' NEED_ANNOTATION_FOR_VAR = 'Need type annotation for variable' ITERABLE_EXPECTED = 'Iterable expected' +ASYNC_ITERABLE_EXPECTED = 'AsyncIterable expected' INCOMPATIBLE_TYPES_IN_FOR = 'Incompatible types in for statement' INCOMPATIBLE_ARRAY_VAR_ARGS = 'Incompatible variable arguments in call' INVALID_SLICE_INDEX = 'Slice index must be an integer or None' diff --git a/mypy/strconv.py b/mypy/strconv.py index e898a7e6c40f..cb48f8ec045c 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -199,8 +199,10 @@ def visit_while_stmt(self, o): return self.dump(a, o) def visit_for_stmt(self, o): - a = [o.index] - a.extend([o.expr, o.body]) + a = [] + if o.is_async: + a.append(('Async', '')) + a.extend([o.index, o.expr, o.body]) if o.else_body: a.append(('Else', o.else_body.body)) return self.dump(a, o) @@ -267,6 +269,8 @@ def visit_try_stmt(self, o): def visit_with_stmt(self, o): a = [] + if o.is_async: + a.append(('Async', '')) for i in range(len(o.expr)): a.append(('Expr', [o.expr[i]])) if o.target[i]: diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test index 4cb97e934147..72e1c75483f6 100644 --- a/test-data/unit/check-async-await.test +++ b/test-data/unit/check-async-await.test @@ -58,10 +58,12 @@ main:6: error: Incompatible return value type (got "int", expected "str") [case testAsyncFor] # options: fast_parser -class C: - def __next__(self): return self - async def __aiter__(self) -> int: return 0 +from typing import AsyncIterator +class C(AsyncIterator[int]): + async def __anext__(self) -> int: return 0 async def f() -> None: async for x in C(): - pass # TODO: reveal or check type + reveal_type(x) # E: Revealed type is 'builtins.int*' [builtins fixtures/async_await.py] +[out] +main: note: In function "f": diff --git a/test-data/unit/fixtures/async_await.py b/test-data/unit/fixtures/async_await.py index 742cdb42fe38..4a00556bb597 100644 --- a/test-data/unit/fixtures/async_await.py +++ b/test-data/unit/fixtures/async_await.py @@ -1,5 +1,7 @@ import typing +class object: + def __init__(self): pass +class type: pass class function: pass class int: pass -class object: pass class str: pass diff --git a/test-data/unit/lib-stub/typing.py b/test-data/unit/lib-stub/typing.py index fe68df81279f..3e539f1f5e02 100644 --- a/test-data/unit/lib-stub/typing.py +++ b/test-data/unit/lib-stub/typing.py @@ -61,6 +61,15 @@ class Awaitable(Generic[T]): @abstractmethod def __await__(self) -> Generator[Any, Any, T]: pass +class AsyncIterable(Generic[T]): + @abstractmethod + def __aiter__(self) -> 'AsyncIterator[T]': pass + +class AsyncIterator(AsyncIterable[T], Generic[T]): + def __aiter__(self) -> 'AsyncIterator[T]': return self + @abstractmethod + def __anext__(self) -> Awaitable[T]: pass + class Sequence(Generic[T]): @abstractmethod def __getitem__(self, n: Any) -> T: pass From a8e9d7e11d2c0d10782bb5b580386ca61d0451f4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 3 Jul 2016 16:51:54 -0700 Subject: [PATCH 05/28] Clear unneeded TODOs. --- mypy/nodes.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mypy/nodes.py b/mypy/nodes.py index cbfaba725b63..217d55241f3f 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1712,9 +1712,6 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class AwaitExpr(Node): """Await expression (await ...).""" - # TODO: [de]serialize() - # TODO: arg type must be AsyncIterator[E] or Generator[E, ..., ...] (or Iterator[E]?) - # and then the return type is E expr = None # type: Node From 41c5ed0bc5639c70ce213c874e2f2eecfae6d8d7 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 3 Jul 2016 17:40:46 -0700 Subject: [PATCH 06/28] Fledgeling `async with` support. --- mypy/checker.py | 31 ++++++++++++++++++++++----- test-data/unit/check-async-await.test | 12 +++++++++++ 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 9f4b83af7078..44bfe5dab9b5 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1684,10 +1684,12 @@ def analyze_async_iterable_item_type(self, expr: Node) -> Type: awaitable = echk.check_call(method, [], [], expr)[0] method = echk.analyze_external_member_access('__await__', awaitable, expr) generator = echk.check_call(method, [], [], expr)[0] + # XXX TODO Use get_generator_return_type()? if (isinstance(generator, Instance) and len(generator.args) == 3 and generator.type.fullname() == 'typing.Generator'): return generator.args[2] else: + # XXX TODO What if it's a subclass of Awaitable? return AnyType() def analyze_iterable_item_type(self, expr: Node) -> Type: @@ -1797,15 +1799,32 @@ def check_incompatible_property_override(self, e: Decorator) -> None: def visit_with_stmt(self, s: WithStmt) -> Type: echk = self.expr_checker + if s.is_async: + m_enter = '__aenter__' + m_exit = '__aexit__' + else: + m_enter = '__enter__' + m_exit = '__exit__' for expr, target in zip(s.expr, s.target): ctx = self.accept(expr) - enter = echk.analyze_external_member_access('__enter__', ctx, expr) + enter = echk.analyze_external_member_access(m_enter, ctx, expr) obj = echk.check_call(enter, [], [], expr)[0] + if s.is_async: + self.check_subtype(obj, self.named_type('typing.Awaitable'), expr) if target: + if s.is_async: + # XXX TODO What if it's a subclass of Awaitable? + if (isinstance(obj, Instance) and len(obj.args) == 1 + and obj.type.fullname() == 'typing.Awaitable'): + obj = obj.args[0] + else: + obj = AnyType() self.check_assignment(target, self.temp_node(obj, expr)) - exit = echk.analyze_external_member_access('__exit__', ctx, expr) + exit = echk.analyze_external_member_access(m_exit, ctx, expr) arg = self.temp_node(AnyType(), expr) - echk.check_call(exit, [arg] * 3, [nodes.ARG_POS] * 3, expr) + res = echk.check_call(exit, [arg] * 3, [nodes.ARG_POS] * 3, expr)[0] + if s.is_async: + self.check_subtype(res, self.named_type('typing.Awaitable'), expr) self.accept(s.body) def visit_print_stmt(self, s: PrintStmt) -> Type: @@ -2000,12 +2019,14 @@ def visit_await_expr(self, e: AwaitExpr) -> Type: if isinstance(actual_type, AnyType): return any_type if is_subtype(actual_type, generator_type): - if isinstance(actual_type, Instance) and len(actual_type.args) == 3: + if (isinstance(actual_type, Instance) and len(actual_type.args) == 3 + and actual_type.type.fullname() == 'typing.Generator'): return actual_type.args[2] else: return any_type # Must've been unparameterized Generator. elif is_subtype(actual_type, awaitable_type): - if isinstance(actual_type, Instance) and len(actual_type.args) == 1: + if (isinstance(actual_type, Instance) and len(actual_type.args) == 1 + and actual_type.type.fullname() == 'typing.Awaitable'): return actual_type.args[0] else: return any_type # Must've been unparameterized Awaitable. diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test index 72e1c75483f6..3999f4ca3154 100644 --- a/test-data/unit/check-async-await.test +++ b/test-data/unit/check-async-await.test @@ -67,3 +67,15 @@ async def f() -> None: [builtins fixtures/async_await.py] [out] main: note: In function "f": + +[case testAsyncWith] +# options: fast_parser +class C: + async def __aenter__(self) -> int: pass + async def __aexit__(self, x, y, z) -> None: pass +async def f() -> None: + async with C() as x: + reveal_type(x) # E: Revealed type is 'builtins.int' +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": From 006d4b77bca9285471665ca4a0fbd0edfb85726c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 5 Jul 2016 11:42:26 -0700 Subject: [PATCH 07/28] Disallow `yield [from]` in `async def`. --- mypy/semanal.py | 10 +++++++-- test-data/unit/check-async-await.test | 30 ++++++++++++++++++++++++++ test-data/unit/fixtures/async_await.py | 1 + 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/mypy/semanal.py b/mypy/semanal.py index dc84756fc4bc..0b7d4e178cbb 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -1822,7 +1822,10 @@ def visit_yield_from_expr(self, e: YieldFromExpr) -> None: if not self.is_func_scope(): # not sure self.fail("'yield from' outside function", e, True, blocker=True) else: - self.function_stack[-1].is_generator = True + if self.function_stack[-1].is_coroutine: + self.fail("'yield from' in async function", e, True, blocker=True) + else: + self.function_stack[-1].is_generator = True if e.expr: e.expr.accept(self) @@ -2075,7 +2078,10 @@ def visit_yield_expr(self, expr: YieldExpr) -> None: if not self.is_func_scope(): self.fail("'yield' outside function", expr, True, blocker=True) else: - self.function_stack[-1].is_generator = True + if self.function_stack[-1].is_coroutine: + self.fail("'yield' in async function", expr, True, blocker=True) + else: + self.function_stack[-1].is_generator = True if expr.expr: expr.expr.accept(self) diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test index 3999f4ca3154..c6820e6a773b 100644 --- a/test-data/unit/check-async-await.test +++ b/test-data/unit/check-async-await.test @@ -79,3 +79,33 @@ async def f() -> None: [builtins fixtures/async_await.py] [out] main: note: In function "f": + +[case testNoYieldInAsyncDef] +# options: fast_parser +async def f(): + yield None +async def g(): + yield +async def h(): + x = yield +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": +main:3: error: 'yield' in async function +main: note: In function "g": +main:5: error: 'yield' in async function +main: note: In function "h": +main:7: error: 'yield' in async function + +[case testNoYieldFromInAsyncDef] +# options: fast_parser +async def f(): + yield from [] +async def g(): + x = yield from [] +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": +main:3: error: 'yield from' in async function +main: note: In function "g": +main:5: error: 'yield from' in async function diff --git a/test-data/unit/fixtures/async_await.py b/test-data/unit/fixtures/async_await.py index 4a00556bb597..4bbc591da762 100644 --- a/test-data/unit/fixtures/async_await.py +++ b/test-data/unit/fixtures/async_await.py @@ -5,3 +5,4 @@ class type: pass class function: pass class int: pass class str: pass +class list: pass From eba9f0c41a7fd835dcdc0a6de9ff5058640661b9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 6 Jul 2016 13:47:23 -0700 Subject: [PATCH 08/28] Check Python version before looking up typing.Awaitable. Ensure 'async def' is a syntax error in Python 2 (at least with the fast parser). --- mypy/checker.py | 6 +++++- test-data/unit/check-async-await.test | 5 +++++ test-data/unit/python2eval.test | 8 ++++++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 44bfe5dab9b5..7c8f6f341112 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -292,8 +292,12 @@ def is_generator_return_type(self, typ: Type) -> bool: True if either Generator or Awaitable is a supertype of `typ`. """ gt = self.named_generic_type('typing.Generator', [AnyType(), AnyType(), AnyType()]) + if is_subtype(gt, typ): + return True + if self.options.python_version < (3, 5): + return False at = self.named_generic_type('typing.Awaitable', [AnyType()]) - return is_subtype(gt, typ) or is_subtype(at, typ) + return is_subtype(at, typ) def get_generator_yield_type(self, return_type: Type) -> Type: """Given the declared return type of a generator (t), return the type it yields (ty).""" diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test index c6820e6a773b..ea33d9f9503a 100644 --- a/test-data/unit/check-async-await.test +++ b/test-data/unit/check-async-await.test @@ -109,3 +109,8 @@ main: note: In function "f": main:3: error: 'yield from' in async function main: note: In function "g": main:5: error: 'yield from' in async function + +[case testNoAsyncDefInPY2_python2] +# options: fast_parser +async def f(): # E: invalid syntax + pass diff --git a/test-data/unit/python2eval.test b/test-data/unit/python2eval.test index 4f9b633efc4d..944ce614eefa 100644 --- a/test-data/unit/python2eval.test +++ b/test-data/unit/python2eval.test @@ -440,3 +440,11 @@ re.subn(upat, u'', u'')[0] + u'' re.subn(ure, lambda m: u'', u'')[0] + u'' re.subn(upat, lambda m: u'', u'')[0] + u'' [out] + +[case testYieldRegressionTypingAwaitable_python2] +# Make sure we don't reference typing.Awaitable in Python 2 mode. +def g() -> int: + yield +[out] +_program.py: note: In function "g": +_program.py:2: error: The return type of a generator function should be "Generator" or one of its supertypes From 217116acfeae0c899ff56fb511162c9e4b4317f0 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 6 Jul 2016 14:50:50 -0700 Subject: [PATCH 09/28] Vast strides in accuracy for visit_await_expr(). --- mypy/checker.py | 24 +++++-------- test-data/unit/check-async-await.test | 51 +++++++++++++++++++++++---- 2 files changed, 53 insertions(+), 22 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 7c8f6f341112..b2a3027c5b54 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -2015,30 +2015,24 @@ def visit_yield_expr(self, e: YieldExpr) -> Type: return self.get_generator_receive_type(return_type) def visit_await_expr(self, e: AwaitExpr) -> Type: - any_type = AnyType() - awaitable_type = self.named_generic_type('typing.Awaitable', [any_type]) - generator_type = self.named_generic_type('typing.Generator', [any_type] * 3) - expected_type = UnionType([awaitable_type, generator_type]) + expected_type = self.type_context[-1] + if expected_type is not None: + expected_type = self.named_generic_type('typing.Awaitable', [expected_type]) actual_type = self.accept(e.expr, expected_type) if isinstance(actual_type, AnyType): - return any_type - if is_subtype(actual_type, generator_type): - if (isinstance(actual_type, Instance) and len(actual_type.args) == 3 - and actual_type.type.fullname() == 'typing.Generator'): - return actual_type.args[2] - else: - return any_type # Must've been unparameterized Generator. - elif is_subtype(actual_type, awaitable_type): + return AnyType() + awaitable_type = self.named_generic_type('typing.Awaitable', [AnyType()]) + if is_subtype(actual_type, awaitable_type): if (isinstance(actual_type, Instance) and len(actual_type.args) == 1 and actual_type.type.fullname() == 'typing.Awaitable'): return actual_type.args[0] else: - return any_type # Must've been unparameterized Awaitable. - msg = "{} (actual type {}, expected Awaitable or Generator)".format( + return AnyType() # Must've been unparameterized Awaitable. + msg = "{} (actual type {}, expected Awaitable)".format( messages.INCOMPATIBLE_TYPES_IN_AWAIT, self.msg.format(actual_type)) self.fail(msg, e) - return any_type + return AnyType() # # Helpers diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test index ea33d9f9503a..5b84369c1a9c 100644 --- a/test-data/unit/check-async-await.test +++ b/test-data/unit/check-async-await.test @@ -21,16 +21,53 @@ async def f() -> int: return x [builtins fixtures/async_await.py] -[case testAwaitGenerator] +[case testAwaitDefaultContext] # options: fast_parser -from typing import Any, Generator -def g() -> Generator[int, int, int]: - x = yield 0 - return x +from typing import Any, TypeVar +T = TypeVar('T') +async def f(x: T) -> T: + y = await f(x) + reveal_type(y) + return y +[out] +main: note: In function "f": +main:6: error: Revealed type is 'T`-1' + +[case testAwaitAnyContext] +# options: fast_parser +from typing import Any, TypeVar +T = TypeVar('T') +async def f(x: T) -> T: + y = await f(x) # type: Any + reveal_type(y) + return y +[out] +main: note: In function "f": +main:6: error: Revealed type is 'Any' + +[case testAwaitExplicitContext] +# options: fast_parser +from typing import Any, TypeVar +T = TypeVar('T') +async def f(x: T) -> T: + y = await f(x) # type: int + reveal_type(y) +[out] +main: note: In function "f": +main:5: error: Argument 1 to "f" has incompatible type "T"; expected "int" +main:6: error: Revealed type is 'builtins.int' + +[case testAwaitGeneratorError] +# options: fast_parser +from typing import Any, Iterator +def g() -> Iterator[Any]: + yield async def f() -> int: x = await g() return x -[builtins fixtures/async_await.py] +[out] +main: note: In function "f": +main: error: Incompatible types in await (actual type Iterator[Any], expected Awaitable) [case testAwaitArgumentError] # options: fast_parser @@ -42,7 +79,7 @@ async def f() -> int: [builtins fixtures/async_await.py] [out] main: note: In function "f": -main: error: Incompatible types in await (actual type "int", expected Awaitable or Generator) +main: error: Incompatible types in await (actual type "int", expected Awaitable) [case testAwaitResultError] # options: fast_parser From f294d81c8eccd2ba7be445cfa549494eed6e1b8d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 11 Jul 2016 15:36:46 -0700 Subject: [PATCH 10/28] Add `@with_line` to PEP 492 visit function definitions. --- mypy/fastparse.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 0770adf712a4..e65cda8cb04e 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -769,15 +769,18 @@ def visit_Index(self, n: ast35.Index) -> Node: # AsyncFunctionDef(identifier name, arguments args, # stmt* body, expr* decorator_list, expr? returns, string? type_comment) + @with_line def visit_AsyncFunctionDef(self, n: ast35.AsyncFunctionDef) -> Node: return self.do_func_def(n, is_coroutine=True) # Await(expr value) + @with_line def visit_Await(self, n: ast35.Await) -> Node: v = self.visit(n.value) return AwaitExpr(v) # AsyncFor(expr target, expr iter, stmt* body, stmt* orelse) + @with_line def visit_AsyncFor(self, n: ast35.AsyncFor) -> Node: r = ForStmt(self.visit(n.target), self.visit(n.iter), @@ -787,6 +790,7 @@ def visit_AsyncFor(self, n: ast35.AsyncFor) -> Node: return r # AsyncWith(withitem* items, stmt* body) + @with_line def visit_AsyncWith(self, n: ast35.AsyncWith) -> Node: r = WithStmt([self.visit(i.context_expr) for i in n.items], [self.visit(i.optional_vars) for i in n.items], From 9e2a2ebafd416eb11edaa84795be45c3388932ac Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 14 Jul 2016 11:15:38 -0700 Subject: [PATCH 11/28] Fix tests now that errors have line numbers. --- test-data/unit/check-async-await.test | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test index 5b84369c1a9c..f7318a2486cf 100644 --- a/test-data/unit/check-async-await.test +++ b/test-data/unit/check-async-await.test @@ -67,7 +67,7 @@ async def f() -> int: return x [out] main: note: In function "f": -main: error: Incompatible types in await (actual type Iterator[Any], expected Awaitable) +main:6: error: Incompatible types in await (actual type Iterator[Any], expected Awaitable) [case testAwaitArgumentError] # options: fast_parser @@ -79,7 +79,7 @@ async def f() -> int: [builtins fixtures/async_await.py] [out] main: note: In function "f": -main: error: Incompatible types in await (actual type "int", expected Awaitable) +main:5: error: Incompatible types in await (actual type "int", expected Awaitable) [case testAwaitResultError] # options: fast_parser From c077d32d6256e855cd1e634bd5c8ba9bef943404 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 19 Jul 2016 14:50:57 -0700 Subject: [PATCH 12/28] Tweak tests for async/await a bit. --- test-data/unit/check-async-await.test | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test index f7318a2486cf..893dc41de438 100644 --- a/test-data/unit/check-async-await.test +++ b/test-data/unit/check-async-await.test @@ -18,12 +18,15 @@ reveal_type(f()) # E: Revealed type is 'typing.Awaitable[builtins.int]' # options: fast_parser async def f() -> int: x = await f() + reveal_type(x) # E: Revealed type is 'builtins.int' return x [builtins fixtures/async_await.py] +[out] +main: note: In function "f": [case testAwaitDefaultContext] # options: fast_parser -from typing import Any, TypeVar +from typing import TypeVar T = TypeVar('T') async def f(x: T) -> T: y = await f(x) @@ -47,7 +50,7 @@ main:6: error: Revealed type is 'Any' [case testAwaitExplicitContext] # options: fast_parser -from typing import Any, TypeVar +from typing import TypeVar T = TypeVar('T') async def f(x: T) -> T: y = await f(x) # type: int @@ -57,7 +60,7 @@ main: note: In function "f": main:5: error: Argument 1 to "f" has incompatible type "T"; expected "int" main:6: error: Revealed type is 'builtins.int' -[case testAwaitGeneratorError] +[case testAwaitIteratorError] # options: fast_parser from typing import Any, Iterator def g() -> Iterator[Any]: @@ -83,6 +86,17 @@ main:5: error: Incompatible types in await (actual type "int", expected Awaitabl [case testAwaitResultError] # options: fast_parser +async def g() -> int: + return 0 +async def f() -> str: + x = await g() # type: str +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": +main:5: error: Incompatible types in assignment (expression has type "int", variable has type "str") + +[case testAwaitReturnError] +# options: fast_parser async def g() -> int: return 0 async def f() -> str: From c8d1cda01b677b917df55f82507055f3c48eb6f4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 19 Jul 2016 17:22:27 -0700 Subject: [PATCH 13/28] Get rid of remaining XXX issues. is_generator_return_type() now takes an extra is_coroutine flag. --- mypy/checker.py | 76 +++++++++++++++++++++++-------------------------- 1 file changed, 35 insertions(+), 41 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index b2a3027c5b54..1464be98b3cf 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -275,35 +275,35 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: # object or Any). If ts/tr are not given, both are Any. # # A coroutine must define a return type corresponding to tr; the - # other two are fixed: ty is AbstractFuture[tr], ts in Any. The - # "external" return type (seen by the caller) is Awaitable[tr]. + # other two are unconstrained. The "external" return type (seen + # by the caller) is Awaitable[tr]. # - # There are several useful methods: + # There are several useful methods, each taking a type t and a + # flag c indicating whether it's for a generator or coroutine: # - # - is_generator_return_type(t) returns whether t is a Generator, - # Iterator, Iterable, or Awaitable. - # - get_generator_yield_type(t) returns ty. - # - get_generator_receive_type(t) returns ts. - # - get_generator_return_type(t) returns tr. + # - is_generator_return_type(t, c) returns whether t is a Generator, + # Iterator, Iterable (if not c), or Awaitable (if c). + # - get_generator_yield_type(t, c) returns ty. + # - get_generator_receive_type(t, c) returns ts. + # - get_generator_return_type(t, c) returns tr. - def is_generator_return_type(self, typ: Type) -> bool: - """Is `typ` a valid type for generator? + def is_generator_return_type(self, typ: Type, is_coroutine: bool) -> bool: + """Is `typ` a valid type for a generator/coroutine? True if either Generator or Awaitable is a supertype of `typ`. """ - gt = self.named_generic_type('typing.Generator', [AnyType(), AnyType(), AnyType()]) - if is_subtype(gt, typ): - return True - if self.options.python_version < (3, 5): - return False - at = self.named_generic_type('typing.Awaitable', [AnyType()]) - return is_subtype(at, typ) + if is_coroutine: + at = self.named_generic_type('typing.Awaitable', [AnyType()]) + return is_subtype(at, typ) + else: + gt = self.named_generic_type('typing.Generator', [AnyType(), AnyType(), AnyType()]) + return is_subtype(gt, typ) - def get_generator_yield_type(self, return_type: Type) -> Type: + def get_generator_yield_type(self, return_type: Type, is_coroutine: bool) -> Type: """Given the declared return type of a generator (t), return the type it yields (ty).""" if isinstance(return_type, AnyType): return AnyType() - elif not self.is_generator_return_type(return_type): + elif not self.is_generator_return_type(return_type, is_coroutine): # If the function doesn't have a proper Generator (or superclass) return type, anything # is permissible. return AnyType() @@ -311,7 +311,7 @@ def get_generator_yield_type(self, return_type: Type) -> Type: # Same as above, but written as a separate branch so the typechecker can understand. return AnyType() elif return_type.type.fullname() == 'typing.Awaitable': - return AnyType() # XXX TODO: AbstractFuture[tr] + return AnyType() elif return_type.args: return return_type.args[0] else: @@ -320,11 +320,11 @@ def get_generator_yield_type(self, return_type: Type) -> Type: # be accessed so any type is acceptable. return AnyType() - def get_generator_receive_type(self, return_type: Type) -> Type: + def get_generator_receive_type(self, return_type: Type, is_coroutine: bool) -> Type: """Given a declared generator return type (t), return the type its yield receives (ts).""" if isinstance(return_type, AnyType): return AnyType() - elif not self.is_generator_return_type(return_type): + elif not self.is_generator_return_type(return_type, is_coroutine): # If the function doesn't have a proper Generator (or superclass) return type, anything # is permissible. return AnyType() @@ -346,11 +346,11 @@ def get_generator_receive_type(self, return_type: Type) -> Type: # values. return Void() - def get_generator_return_type(self, return_type: Type) -> Type: + def get_generator_return_type(self, return_type: Type, is_coroutine: bool) -> Type: """Given the declared return type of a generator (t), return the type it returns (tr).""" if isinstance(return_type, AnyType): return AnyType() - elif not self.is_generator_return_type(return_type): + elif not self.is_generator_return_type(return_type, is_coroutine): # If the function doesn't have a proper Generator (or superclass) return type, anything # is permissible. return AnyType() @@ -506,7 +506,7 @@ def is_implicit_any(t: Type) -> bool: # Check that Generator functions have the appropriate return type. if defn.is_generator: - if not self.is_generator_return_type(typ.ret_type): + if not self.is_generator_return_type(typ.ret_type, defn.is_coroutine): self.fail(messages.INVALID_RETURN_TYPE_FOR_GENERATOR, typ) # Python 2 generators aren't allowed to return values. @@ -1395,8 +1395,10 @@ def visit_return_stmt(self, s: ReturnStmt) -> Type: """Type check a return statement.""" self.binder.breaking_out = True if self.is_within_function(): - if self.function_stack[-1].is_generator: - return_type = self.get_generator_return_type(self.return_types[-1]) + defn = self.function_stack[-1] + if defn.is_generator: + return_type = self.get_generator_return_type(self.return_types[-1], + defn.is_coroutine) else: return_type = self.return_types[-1] @@ -1686,15 +1688,7 @@ def analyze_async_iterable_item_type(self, expr: Node) -> Type: iterator = echk.check_call(method, [], [], expr)[0] method = echk.analyze_external_member_access('__anext__', iterator, expr) awaitable = echk.check_call(method, [], [], expr)[0] - method = echk.analyze_external_member_access('__await__', awaitable, expr) - generator = echk.check_call(method, [], [], expr)[0] - # XXX TODO Use get_generator_return_type()? - if (isinstance(generator, Instance) and len(generator.args) == 3 - and generator.type.fullname() == 'typing.Generator'): - return generator.args[2] - else: - # XXX TODO What if it's a subclass of Awaitable? - return AnyType() + return self.get_generator_return_type(awaitable, True) def analyze_iterable_item_type(self, expr: Node) -> Type: """Analyse iterable expression and return iterator item type.""" @@ -1876,8 +1870,8 @@ def visit_yield_from_expr(self, e: YieldFromExpr) -> Type: # Check that the iterator's item type matches the type yielded by the Generator function # containing this `yield from` expression. - expected_item_type = self.get_generator_yield_type(return_type) - actual_item_type = self.get_generator_yield_type(iter_type) + expected_item_type = self.get_generator_yield_type(return_type, False) + actual_item_type = self.get_generator_yield_type(iter_type, False) self.check_subtype(actual_item_type, expected_item_type, e, messages.INCOMPATIBLE_TYPES_IN_YIELD_FROM, @@ -1886,7 +1880,7 @@ def visit_yield_from_expr(self, e: YieldFromExpr) -> Type: # Determine the type of the entire yield from expression. if (isinstance(iter_type, Instance) and iter_type.type.fullname() == 'typing.Generator'): - return self.get_generator_return_type(iter_type) + return self.get_generator_return_type(iter_type, False) else: # Non-Generators don't return anything from `yield from` expressions. return Void() @@ -2001,7 +1995,7 @@ def visit_backquote_expr(self, e: BackquoteExpr) -> Type: def visit_yield_expr(self, e: YieldExpr) -> Type: return_type = self.return_types[-1] - expected_item_type = self.get_generator_yield_type(return_type) + expected_item_type = self.get_generator_yield_type(return_type, False) if e.expr is None: if (not (isinstance(expected_item_type, Void) or isinstance(expected_item_type, AnyType)) @@ -2012,7 +2006,7 @@ def visit_yield_expr(self, e: YieldExpr) -> Type: self.check_subtype(actual_item_type, expected_item_type, e, messages.INCOMPATIBLE_TYPES_IN_YIELD, 'actual type', 'expected type') - return self.get_generator_receive_type(return_type) + return self.get_generator_receive_type(return_type, False) def visit_await_expr(self, e: AwaitExpr) -> Type: expected_type = self.type_context[-1] From b5b154be3dc257ed9e750a14aea0bd8eeaff72c7 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 19 Jul 2016 17:28:29 -0700 Subject: [PATCH 14/28] Move PEP 492 nodes back where they belong. --- mypy/fastparse.py | 64 +++++++++++++++++++++++------------------------ 1 file changed, 31 insertions(+), 33 deletions(-) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index e65cda8cb04e..1db7004f6713 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -247,6 +247,12 @@ def visit_Module(self, mod: ast35.Module) -> Node: def visit_FunctionDef(self, n: ast35.FunctionDef) -> Node: return self.do_func_def(n) + # AsyncFunctionDef(identifier name, arguments args, + # stmt* body, expr* decorator_list, expr? returns, string? type_comment) + @with_line + def visit_AsyncFunctionDef(self, n: ast35.AsyncFunctionDef) -> Node: + return self.do_func_def(n, is_coroutine=True) + def do_func_def(self, n: Union[ast35.FunctionDef, ast35.AsyncFunctionDef], is_coroutine: bool = False) -> Node: """Helper shared between visit_FunctionDef and visit_AsyncFunctionDef.""" @@ -427,6 +433,16 @@ def visit_For(self, n: ast35.For) -> Node: self.as_block(n.body, n.lineno), self.as_block(n.orelse, n.lineno)) + # AsyncFor(expr target, expr iter, stmt* body, stmt* orelse) + @with_line + def visit_AsyncFor(self, n: ast35.AsyncFor) -> Node: + r = ForStmt(self.visit(n.target), + self.visit(n.iter), + self.as_block(n.body, n.lineno), + self.as_block(n.orelse, n.lineno)) + r.is_async = True + return r + # While(expr test, stmt* body, stmt* orelse) @with_line def visit_While(self, n: ast35.While) -> Node: @@ -448,6 +464,15 @@ def visit_With(self, n: ast35.With) -> Node: [self.visit(i.optional_vars) for i in n.items], self.as_block(n.body, n.lineno)) + # AsyncWith(withitem* items, stmt* body) + @with_line + def visit_AsyncWith(self, n: ast35.AsyncWith) -> Node: + r = WithStmt([self.visit(i.context_expr) for i in n.items], + [self.visit(i.optional_vars) for i in n.items], + self.as_block(n.body, n.lineno)) + r.is_async = True + return r + # Raise(expr? exc, expr? cause) @with_line def visit_Raise(self, n: ast35.Raise) -> Node: @@ -633,6 +658,12 @@ def visit_GeneratorExp(self, n: ast35.GeneratorExp) -> GeneratorExpr: iters, ifs_list) + # Await(expr value) + @with_line + def visit_Await(self, n: ast35.Await) -> Node: + v = self.visit(n.value) + return AwaitExpr(v) + # Yield(expr? value) @with_line def visit_Yield(self, n: ast35.Yield) -> Node: @@ -765,39 +796,6 @@ def visit_ExtSlice(self, n: ast35.ExtSlice) -> Node: def visit_Index(self, n: ast35.Index) -> Node: return self.visit(n.value) - # PEP 492 nodes: 'async def', 'await', 'async for', 'async with'. - - # AsyncFunctionDef(identifier name, arguments args, - # stmt* body, expr* decorator_list, expr? returns, string? type_comment) - @with_line - def visit_AsyncFunctionDef(self, n: ast35.AsyncFunctionDef) -> Node: - return self.do_func_def(n, is_coroutine=True) - - # Await(expr value) - @with_line - def visit_Await(self, n: ast35.Await) -> Node: - v = self.visit(n.value) - return AwaitExpr(v) - - # AsyncFor(expr target, expr iter, stmt* body, stmt* orelse) - @with_line - def visit_AsyncFor(self, n: ast35.AsyncFor) -> Node: - r = ForStmt(self.visit(n.target), - self.visit(n.iter), - self.as_block(n.body, n.lineno), - self.as_block(n.orelse, n.lineno)) - r.is_async = True - return r - - # AsyncWith(withitem* items, stmt* body) - @with_line - def visit_AsyncWith(self, n: ast35.AsyncWith) -> Node: - r = WithStmt([self.visit(i.context_expr) for i in n.items], - [self.visit(i.optional_vars) for i in n.items], - self.as_block(n.body, n.lineno)) - r.is_async = True - return r - class TypeConverter(ast35.NodeTransformer): def __init__(self, line: int = -1) -> None: From 1d14cca08f5302370e43b91b254c10e4fccd6a5c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 19 Jul 2016 19:12:29 -0700 Subject: [PATCH 15/28] Respond to code review. --- mypy/nodes.py | 2 +- mypy/semanal.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/mypy/nodes.py b/mypy/nodes.py index 217d55241f3f..cc77c8b82c57 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -416,7 +416,7 @@ class FuncItem(FuncBase): # Is this an overload variant of function with more than one overload variant? is_overload = False is_generator = False # Contains a yield statement? - is_coroutine = False # Defined using 'async def'? + is_coroutine = False # Defined using 'async def' syntax? is_static = False # Uses @staticmethod? is_class = False # Uses @classmethod? # Variants of function with type variables with values expanded diff --git a/mypy/semanal.py b/mypy/semanal.py index 0b7d4e178cbb..ce9281de02fe 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -2090,8 +2090,7 @@ def visit_await_expr(self, expr: AwaitExpr) -> None: self.fail("'await' outside function", expr) elif not self.function_stack[-1].is_coroutine: self.fail("'await' outside coroutine ('async def')", expr) - if expr.expr: - expr.expr.accept(self) + expr.expr.accept(self) # # Helpers From b21aae98d8745f4d3354faf3a3785a48bbbb41ac Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 20 Jul 2016 11:40:11 -0700 Subject: [PATCH 16/28] Add tests expecting errors from async for/with. --- test-data/unit/check-async-await.test | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test index 893dc41de438..e5585da69e2e 100644 --- a/test-data/unit/check-async-await.test +++ b/test-data/unit/check-async-await.test @@ -119,6 +119,18 @@ async def f() -> None: [out] main: note: In function "f": +[case testAsyncForError] +# options: fast_parser +from typing import AsyncIterator +async def f() -> None: + async for x in [1]: + pass +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": +main:4: error: AsyncIterable expected +main:4: error: List[int] has no attribute "__aiter__" + [case testAsyncWith] # options: fast_parser class C: @@ -131,6 +143,20 @@ async def f() -> None: [out] main: note: In function "f": +[case testAsyncWithError] +# options: fast_parser +class C: + def __enter__(self) -> int: pass + def __exit__(self, x, y, z) -> None: pass +async def f() -> None: + async with C() as x: + pass +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": +main:6: error: "C" has no attribute "__aenter__"; maybe "__enter__"? +main:6: error: "C" has no attribute "__aexit__"; maybe "__exit__"? + [case testNoYieldInAsyncDef] # options: fast_parser async def f(): From 3a766cb40395294fec4984d5e15a340642fc529b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 20 Jul 2016 11:43:59 -0700 Subject: [PATCH 17/28] Test that await is an error. --- test-data/unit/check-async-await.test | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test index e5585da69e2e..533e0ff01876 100644 --- a/test-data/unit/check-async-await.test +++ b/test-data/unit/check-async-await.test @@ -60,6 +60,19 @@ main: note: In function "f": main:5: error: Argument 1 to "f" has incompatible type "T"; expected "int" main:6: error: Revealed type is 'builtins.int' +[case testAwaitGeneratorError] +# options: fast_parser +from typing import Any, Generator +def g() -> Generator[int, None, str]: + yield 0 + return '' +async def f() -> int: + x = await g() + return x +[out] +main: note: In function "f": +main:7: error: Incompatible types in await (actual type Generator[int, None, str], expected Awaitable) + [case testAwaitIteratorError] # options: fast_parser from typing import Any, Iterator From 0746ade41d54c73d64f3af06c748322f75232c8b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 20 Jul 2016 13:09:21 -0700 Subject: [PATCH 18/28] Verify that `yield from` does not accept coroutines. This revealed a spurious error "Function does not return a value", fixed that. --- mypy/checker.py | 6 +++++- test-data/unit/check-async-await.test | 13 +++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 1464be98b3cf..0a019fbd567a 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1883,7 +1883,11 @@ def visit_yield_from_expr(self, e: YieldFromExpr) -> Type: return self.get_generator_return_type(iter_type, False) else: # Non-Generators don't return anything from `yield from` expressions. - return Void() + # However special-case Any (which might be produced by an error). + if isinstance(actual_item_type, AnyType): + return AnyType() + else: + return Void() def visit_member_expr(self, e: MemberExpr) -> Type: return self.expr_checker.visit_member_expr(e) diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test index 533e0ff01876..73ed4b50c87f 100644 --- a/test-data/unit/check-async-await.test +++ b/test-data/unit/check-async-await.test @@ -204,3 +204,16 @@ main:5: error: 'yield from' in async function # options: fast_parser async def f(): # E: invalid syntax pass + +[case testYieldFromNoAwaitable] +# options: fast_parser +from typing import Any, Generator +async def f() -> str: + return '' +def g() -> Generator[Any, None, str]: + x = yield from f() + return x +[builtins fixtures/async_await.py] +[out] +main: note: In function "g": +main:6: error: "yield from" can't be applied to Awaitable[str] From ebd9de5b274b7ed42572b3bba0611f658e96df36 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 20 Jul 2016 15:23:10 -0700 Subject: [PATCH 19/28] Disallow return value in generator declared as -> Iterator. --- mypy/checker.py | 5 +++-- test-data/unit/check-statements.test | 4 +++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 0a019fbd567a..3689952e1e30 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -267,7 +267,8 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: # # Classic generators can be parameterized with three types: # - ty is the yield type (the type of y in `yield y`) - # - ts is the type receive by yield (the type of s in `s = yield`) + # - ts is the type received by yield (the type of s in `s = yield`) + # (it's named `ts` after `send()`, since `tr` is `return`). # - tr is the return type (the type of r in `return r`) # # A classic generator must define a return type that's either @@ -374,7 +375,7 @@ def get_generator_return_type(self, return_type: Type, is_coroutine: bool) -> Ty else: # `return_type` is supertype of Generator, so callers won't be able to see the return # type when used in a `yield from` expression. - return AnyType() + return Void() def visit_func_def(self, defn: FuncDef) -> Type: """Type check a function definition.""" diff --git a/test-data/unit/check-statements.test b/test-data/unit/check-statements.test index 2812eeaadd5f..0be318fd0c2c 100644 --- a/test-data/unit/check-statements.test +++ b/test-data/unit/check-statements.test @@ -81,12 +81,14 @@ def f() -> Generator[int, None, None]: [out] main: note: In function "f": -[case testReturnInIterator] +[case testReturnInIteratorError] from typing import Iterator def f() -> Iterator[int]: yield 1 return "foo" [out] +main: note: In function "f": +main:4: error: No return value expected -- If statement -- ------------ From 3b687c56e0bcc3e26641de6661f4e2c37e01dba7 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 20 Jul 2016 15:27:05 -0700 Subject: [PATCH 20/28] Fix typo in comment. --- mypy/checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 3689952e1e30..a4e8dace064c 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -339,7 +339,7 @@ def get_generator_receive_type(self, return_type: Type, is_coroutine: bool) -> T else: return AnyType() elif return_type.type.fullname() == 'typing.Awaitable': - # Generator is one of the two types which specify the type of values it can receive. + # Awaitable is one of the two types which specify the type of values it can receive. # According to the stub this is always `Any`. return AnyType() else: From 1bab7e0b39cf6597141aaa3ed7408c612a5903df Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 20 Jul 2016 18:08:54 -0700 Subject: [PATCH 21/28] Refactor visit_with_stmt() into separate helper methods for async and regular. Also Use get_generator_return_type() instead of manually unpacking the value. --- mypy/checker.py | 61 ++++++++++++++++++++++++------------------------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index a4e8dace064c..d619e4b1ff2a 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -11,7 +11,7 @@ from mypy.errors import Errors, report_internal_error from mypy.nodes import ( - SymbolTable, Node, MypyFile, Var, + SymbolTable, Node, MypyFile, Var, Expression, OverloadedFuncDef, FuncDef, FuncItem, FuncBase, TypeInfo, ClassDef, GDEF, Block, AssignmentStmt, NameExpr, MemberExpr, IndexExpr, TupleExpr, ListExpr, ExpressionStmt, ReturnStmt, IfStmt, @@ -1797,35 +1797,38 @@ def check_incompatible_property_override(self, e: Decorator) -> None: self.fail(messages.READ_ONLY_PROPERTY_OVERRIDES_READ_WRITE, e) def visit_with_stmt(self, s: WithStmt) -> Type: - echk = self.expr_checker - if s.is_async: - m_enter = '__aenter__' - m_exit = '__aexit__' - else: - m_enter = '__enter__' - m_exit = '__exit__' for expr, target in zip(s.expr, s.target): - ctx = self.accept(expr) - enter = echk.analyze_external_member_access(m_enter, ctx, expr) - obj = echk.check_call(enter, [], [], expr)[0] - if s.is_async: - self.check_subtype(obj, self.named_type('typing.Awaitable'), expr) - if target: - if s.is_async: - # XXX TODO What if it's a subclass of Awaitable? - if (isinstance(obj, Instance) and len(obj.args) == 1 - and obj.type.fullname() == 'typing.Awaitable'): - obj = obj.args[0] - else: - obj = AnyType() - self.check_assignment(target, self.temp_node(obj, expr)) - exit = echk.analyze_external_member_access(m_exit, ctx, expr) - arg = self.temp_node(AnyType(), expr) - res = echk.check_call(exit, [arg] * 3, [nodes.ARG_POS] * 3, expr)[0] if s.is_async: - self.check_subtype(res, self.named_type('typing.Awaitable'), expr) + self.check_async_with_item(expr, target) + else: + self.check_with_item(expr, target) self.accept(s.body) + def check_async_with_item(self, expr: Expression, target: Expression) -> None: + echk = self.expr_checker + ctx = self.accept(expr) + enter = echk.analyze_external_member_access('__aenter__', ctx, expr) + obj = echk.check_call(enter, [], [], expr)[0] + self.check_subtype(obj, self.named_type('typing.Awaitable'), expr) + if target: + obj = self.get_generator_return_type(obj, True) + self.check_assignment(target, self.temp_node(obj, expr)) + exit = echk.analyze_external_member_access('__aexit__', ctx, expr) + arg = self.temp_node(AnyType(), expr) + res = echk.check_call(exit, [arg] * 3, [nodes.ARG_POS] * 3, expr)[0] + self.check_subtype(res, self.named_type('typing.Awaitable'), expr) + + def check_with_item(self, expr: Expression, target: Expression) -> None: + echk = self.expr_checker + ctx = self.accept(expr) + enter = echk.analyze_external_member_access('__enter__', ctx, expr) + obj = echk.check_call(enter, [], [], expr)[0] + if target: + self.check_assignment(target, self.temp_node(obj, expr)) + exit = echk.analyze_external_member_access('__exit__', ctx, expr) + arg = self.temp_node(AnyType(), expr) + res = echk.check_call(exit, [arg] * 3, [nodes.ARG_POS] * 3, expr)[0] + def visit_print_stmt(self, s: PrintStmt) -> Type: for arg in s.args: self.accept(arg) @@ -2022,11 +2025,7 @@ def visit_await_expr(self, e: AwaitExpr) -> Type: return AnyType() awaitable_type = self.named_generic_type('typing.Awaitable', [AnyType()]) if is_subtype(actual_type, awaitable_type): - if (isinstance(actual_type, Instance) and len(actual_type.args) == 1 - and actual_type.type.fullname() == 'typing.Awaitable'): - return actual_type.args[0] - else: - return AnyType() # Must've been unparameterized Awaitable. + return self.get_generator_return_type(actual_type, True) msg = "{} (actual type {}, expected Awaitable)".format( messages.INCOMPATIBLE_TYPES_IN_AWAIT, self.msg.format(actual_type)) From adc32a2e580574c8b727d054c408dc3cbe5e4b34 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 20 Jul 2016 21:29:30 -0700 Subject: [PATCH 22/28] Fix lint error. Correct comment about default ts/tr. --- mypy/checker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index d619e4b1ff2a..736fee412070 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -273,7 +273,7 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: # # A classic generator must define a return type that's either # `Generator[ty, ts, tr]`, Iterator[ty], or Iterable[ty] (or - # object or Any). If ts/tr are not given, both are Any. + # object or Any). If ts/tr are not given, both are Void. # # A coroutine must define a return type corresponding to tr; the # other two are unconstrained. The "external" return type (seen @@ -1827,7 +1827,7 @@ def check_with_item(self, expr: Expression, target: Expression) -> None: self.check_assignment(target, self.temp_node(obj, expr)) exit = echk.analyze_external_member_access('__exit__', ctx, expr) arg = self.temp_node(AnyType(), expr) - res = echk.check_call(exit, [arg] * 3, [nodes.ARG_POS] * 3, expr)[0] + echk.check_call(exit, [arg] * 3, [nodes.ARG_POS] * 3, expr) def visit_print_stmt(self, s: PrintStmt) -> Type: for arg in s.args: From 649386e6fd3a724c942b68f2af66c58b15550d6f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 21 Jul 2016 08:46:44 -0700 Subject: [PATCH 23/28] Improve errors when __aenter__/__aexit__ are not async. With tests. --- mypy/checker.py | 8 +++-- mypy/messages.py | 2 ++ test-data/unit/check-async-await.test | 48 +++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 2 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 736fee412070..ea7e3e56aaba 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1809,14 +1809,18 @@ def check_async_with_item(self, expr: Expression, target: Expression) -> None: ctx = self.accept(expr) enter = echk.analyze_external_member_access('__aenter__', ctx, expr) obj = echk.check_call(enter, [], [], expr)[0] - self.check_subtype(obj, self.named_type('typing.Awaitable'), expr) + self.check_subtype(obj, self.named_type('typing.Awaitable'), expr, + messages.INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AENTER, + 'actual type', 'expected type') if target: obj = self.get_generator_return_type(obj, True) self.check_assignment(target, self.temp_node(obj, expr)) exit = echk.analyze_external_member_access('__aexit__', ctx, expr) arg = self.temp_node(AnyType(), expr) res = echk.check_call(exit, [arg] * 3, [nodes.ARG_POS] * 3, expr)[0] - self.check_subtype(res, self.named_type('typing.Awaitable'), expr) + self.check_subtype(res, self.named_type('typing.Awaitable'), expr, + messages.INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AEXIT, + 'actual type', 'expected type') def check_with_item(self, expr: Expression, target: Expression) -> None: echk = self.expr_checker diff --git a/mypy/messages.py b/mypy/messages.py index b922449f94f9..14d5853fb1b4 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -41,6 +41,8 @@ INCOMPATIBLE_TYPES_IN_ASSIGNMENT = 'Incompatible types in assignment' INCOMPATIBLE_REDEFINITION = 'Incompatible redefinition' INCOMPATIBLE_TYPES_IN_AWAIT = 'Incompatible types in await' +INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AENTER = 'Incompatible types in "async with" for __aenter__' +INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AEXIT = 'Incompatible types in "async with" for __aexit__' INCOMPATIBLE_TYPES_IN_YIELD = 'Incompatible types in yield' INCOMPATIBLE_TYPES_IN_YIELD_FROM = 'Incompatible types in "yield from"' INCOMPATIBLE_TYPES_IN_STR_INTERPOLATION = 'Incompatible types in string interpolation' diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test index 73ed4b50c87f..d6a5e023549d 100644 --- a/test-data/unit/check-async-await.test +++ b/test-data/unit/check-async-await.test @@ -170,6 +170,54 @@ main: note: In function "f": main:6: error: "C" has no attribute "__aenter__"; maybe "__enter__"? main:6: error: "C" has no attribute "__aexit__"; maybe "__exit__"? +[case testAsyncWithErrorBadAenter] +# options: fast_parser +class C: + def __aenter__(self) -> int: pass + async def __aexit__(self, x, y, z) -> None: pass +async def f() -> None: + async with C() as x: # E: Incompatible types in "async with" for __aenter__ (actual type "int", expected type "Awaitable") + pass +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": + +[case testAsyncWithErrorBadAenter2] +# options: fast_parser +class C: + def __aenter__(self) -> None: pass + async def __aexit__(self, x, y, z) -> None: pass +async def f() -> None: + async with C() as x: # E: "__aenter__" of "C" does not return a value + pass +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": + +[case testAsyncWithErrorBadAexit] +# options: fast_parser +class C: + async def __aenter__(self) -> int: pass + def __aexit__(self, x, y, z) -> int: pass +async def f() -> None: + async with C() as x: # E: Incompatible types in "async with" for __aexit__ (actual type "int", expected type "Awaitable") + pass +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": + +[case testAsyncWithErrorBadAexit2] +# options: fast_parser +class C: + async def __aenter__(self) -> int: pass + def __aexit__(self, x, y, z) -> None: pass +async def f() -> None: + async with C() as x: # E: "__aexit__" of "C" does not return a value + pass +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": + [case testNoYieldInAsyncDef] # options: fast_parser async def f(): From 607621c09c30b5cfb08fa1314771451ecc1ce55c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 21 Jul 2016 09:51:03 -0700 Subject: [PATCH 24/28] Refactor: move all extraction of T from Awaitable[T] to a single helper. --- mypy/checker.py | 26 +++++++++++--------------- test-data/unit/check-async-await.test | 6 +++--- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index ea7e3e56aaba..fd1b0d25e2ed 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -377,6 +377,12 @@ def get_generator_return_type(self, return_type: Type, is_coroutine: bool) -> Ty # type when used in a `yield from` expression. return Void() + def get_awaitable_return_type(self, t: Type, ctx: Context, msg: str) -> Type: + """Given a type t, verify that it is a subtype of Awaitable[tr] and return tr.""" + self.check_subtype(t, self.named_type('typing.Awaitable'), ctx, + msg, 'actual type', 'expected type') + return self.get_generator_return_type(t, True) + def visit_func_def(self, defn: FuncDef) -> Type: """Type check a function definition.""" self.check_func_item(defn, name=defn.name()) @@ -1809,18 +1815,15 @@ def check_async_with_item(self, expr: Expression, target: Expression) -> None: ctx = self.accept(expr) enter = echk.analyze_external_member_access('__aenter__', ctx, expr) obj = echk.check_call(enter, [], [], expr)[0] - self.check_subtype(obj, self.named_type('typing.Awaitable'), expr, - messages.INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AENTER, - 'actual type', 'expected type') + obj = self.get_awaitable_return_type( + obj, expr, messages.INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AENTER) if target: - obj = self.get_generator_return_type(obj, True) self.check_assignment(target, self.temp_node(obj, expr)) exit = echk.analyze_external_member_access('__aexit__', ctx, expr) arg = self.temp_node(AnyType(), expr) res = echk.check_call(exit, [arg] * 3, [nodes.ARG_POS] * 3, expr)[0] - self.check_subtype(res, self.named_type('typing.Awaitable'), expr, - messages.INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AEXIT, - 'actual type', 'expected type') + self.get_awaitable_return_type( + res, expr, messages.INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AEXIT) def check_with_item(self, expr: Expression, target: Expression) -> None: echk = self.expr_checker @@ -2027,14 +2030,7 @@ def visit_await_expr(self, e: AwaitExpr) -> Type: actual_type = self.accept(e.expr, expected_type) if isinstance(actual_type, AnyType): return AnyType() - awaitable_type = self.named_generic_type('typing.Awaitable', [AnyType()]) - if is_subtype(actual_type, awaitable_type): - return self.get_generator_return_type(actual_type, True) - msg = "{} (actual type {}, expected Awaitable)".format( - messages.INCOMPATIBLE_TYPES_IN_AWAIT, - self.msg.format(actual_type)) - self.fail(msg, e) - return AnyType() + return self.get_awaitable_return_type(actual_type, e, messages.INCOMPATIBLE_TYPES_IN_AWAIT) # # Helpers diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test index d6a5e023549d..98bbcb14813d 100644 --- a/test-data/unit/check-async-await.test +++ b/test-data/unit/check-async-await.test @@ -71,7 +71,7 @@ async def f() -> int: return x [out] main: note: In function "f": -main:7: error: Incompatible types in await (actual type Generator[int, None, str], expected Awaitable) +main:7: error: Incompatible types in await (actual type Generator[int, None, str], expected type "Awaitable") [case testAwaitIteratorError] # options: fast_parser @@ -83,7 +83,7 @@ async def f() -> int: return x [out] main: note: In function "f": -main:6: error: Incompatible types in await (actual type Iterator[Any], expected Awaitable) +main:6: error: Incompatible types in await (actual type Iterator[Any], expected type "Awaitable") [case testAwaitArgumentError] # options: fast_parser @@ -95,7 +95,7 @@ async def f() -> int: [builtins fixtures/async_await.py] [out] main: note: In function "f": -main:5: error: Incompatible types in await (actual type "int", expected Awaitable) +main:5: error: Incompatible types in await (actual type "int", expected type "Awaitable") [case testAwaitResultError] # options: fast_parser From c79898bf1237f323511874fa1383471f6b06bc3c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 21 Jul 2016 10:38:24 -0700 Subject: [PATCH 25/28] Follow __await__ to extract t from subclass of Awaitable[t]. --- mypy/checker.py | 21 ++++++++++++----- mypy/messages.py | 2 ++ test-data/unit/check-async-await.test | 32 ++++++++++++++++++++++++-- test-data/unit/fixtures/async_await.py | 1 + 4 files changed, 48 insertions(+), 8 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index fd1b0d25e2ed..016d9d3341c0 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -379,9 +379,14 @@ def get_generator_return_type(self, return_type: Type, is_coroutine: bool) -> Ty def get_awaitable_return_type(self, t: Type, ctx: Context, msg: str) -> Type: """Given a type t, verify that it is a subtype of Awaitable[tr] and return tr.""" - self.check_subtype(t, self.named_type('typing.Awaitable'), ctx, - msg, 'actual type', 'expected type') - return self.get_generator_return_type(t, True) + if not self.check_subtype(t, self.named_type('typing.Awaitable'), ctx, + msg, 'actual type', 'expected type'): + return AnyType() + else: + echk = self.expr_checker + method = echk.analyze_external_member_access('__await__', t, ctx) + generator = echk.check_call(method, [], [], ctx)[0] + return self.get_generator_return_type(generator, False) def visit_func_def(self, defn: FuncDef) -> Type: """Type check a function definition.""" @@ -1695,7 +1700,8 @@ def analyze_async_iterable_item_type(self, expr: Node) -> Type: iterator = echk.check_call(method, [], [], expr)[0] method = echk.analyze_external_member_access('__anext__', iterator, expr) awaitable = echk.check_call(method, [], [], expr)[0] - return self.get_generator_return_type(awaitable, True) + return self.get_awaitable_return_type(awaitable, expr, + messages.INCOMPATIBLE_TYPES_IN_ASYNC_FOR) def analyze_iterable_item_type(self, expr: Node) -> Type: """Analyse iterable expression and return iterator item type.""" @@ -2039,10 +2045,12 @@ def visit_await_expr(self, e: AwaitExpr) -> Type: def check_subtype(self, subtype: Type, supertype: Type, context: Context, msg: str = messages.INCOMPATIBLE_TYPES, subtype_label: str = None, - supertype_label: str = None) -> None: + supertype_label: str = None) -> bool: """Generate an error if the subtype is not compatible with supertype.""" - if not is_subtype(subtype, supertype): + if is_subtype(subtype, supertype): + return True + else: if isinstance(subtype, Void): self.msg.does_not_return_value(subtype, context) else: @@ -2056,6 +2064,7 @@ def check_subtype(self, subtype: Type, supertype: Type, context: Context, if extra_info: msg += ' (' + ', '.join(extra_info) + ')' self.fail(msg, context) + return False def named_type(self, name: str) -> Instance: """Return an instance type with type given by the name and no diff --git a/mypy/messages.py b/mypy/messages.py index 14d5853fb1b4..b4240f006dc1 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -43,6 +43,8 @@ INCOMPATIBLE_TYPES_IN_AWAIT = 'Incompatible types in await' INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AENTER = 'Incompatible types in "async with" for __aenter__' INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AEXIT = 'Incompatible types in "async with" for __aexit__' +INCOMPATIBLE_TYPES_IN_ASYNC_FOR = 'Incompatible types in "async for"' + INCOMPATIBLE_TYPES_IN_YIELD = 'Incompatible types in yield' INCOMPATIBLE_TYPES_IN_YIELD_FROM = 'Incompatible types in "yield from"' INCOMPATIBLE_TYPES_IN_STR_INTERPOLATION = 'Incompatible types in string interpolation' diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test index 98bbcb14813d..581e35d99957 100644 --- a/test-data/unit/check-async-await.test +++ b/test-data/unit/check-async-await.test @@ -18,7 +18,7 @@ reveal_type(f()) # E: Revealed type is 'typing.Awaitable[builtins.int]' # options: fast_parser async def f() -> int: x = await f() - reveal_type(x) # E: Revealed type is 'builtins.int' + reveal_type(x) # E: Revealed type is 'builtins.int*' return x [builtins fixtures/async_await.py] [out] @@ -151,7 +151,7 @@ class C: async def __aexit__(self, x, y, z) -> None: pass async def f() -> None: async with C() as x: - reveal_type(x) # E: Revealed type is 'builtins.int' + reveal_type(x) # E: Revealed type is 'builtins.int*' [builtins fixtures/async_await.py] [out] main: note: In function "f": @@ -265,3 +265,31 @@ def g() -> Generator[Any, None, str]: [out] main: note: In function "g": main:6: error: "yield from" can't be applied to Awaitable[str] + +[case testAwaitableSubclass] +# options: fast_parser +from typing import Any, AsyncIterator, Awaitable, Generator +class A(Awaitable[int]): + def __await__(self) -> Generator[Any, None, int]: + yield + return 0 +class C: + def __aenter__(self) -> A: + return A() + def __aexit__(self, *a) -> A: + return A() +class I(AsyncIterator[int]): + def __aiter__(self) -> 'I': + return self + def __anext__(self) -> A: + return A() +async def main() -> None: + x = await A() + reveal_type(x) # E: Revealed type is 'builtins.int' + async with C() as y: + reveal_type(y) # E: Revealed type is 'builtins.int' + async for z in I(): + reveal_type(z) # E: Revealed type is 'builtins.int' +[builtins fixtures/async_await.py] +[out] +main: note: In function "main": diff --git a/test-data/unit/fixtures/async_await.py b/test-data/unit/fixtures/async_await.py index 4bbc591da762..7a166a07294c 100644 --- a/test-data/unit/fixtures/async_await.py +++ b/test-data/unit/fixtures/async_await.py @@ -6,3 +6,4 @@ class function: pass class int: pass class str: pass class list: pass +class tuple: pass From 2e9a6a593c598317457ba565c10c71e8fdebec0e Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 23 Jul 2016 13:27:41 -0700 Subject: [PATCH 26/28] Make get_generator_return_type() default to AnyType() (i.e. as it was). --- mypy/checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 016d9d3341c0..26d822c22d79 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -375,7 +375,7 @@ def get_generator_return_type(self, return_type: Type, is_coroutine: bool) -> Ty else: # `return_type` is supertype of Generator, so callers won't be able to see the return # type when used in a `yield from` expression. - return Void() + return AnyType() def get_awaitable_return_type(self, t: Type, ctx: Context, msg: str) -> Type: """Given a type t, verify that it is a subtype of Awaitable[tr] and return tr.""" From e32a914e51718790f7a34c50434edbe6ed2807ac Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 23 Jul 2016 14:01:48 -0700 Subject: [PATCH 27/28] Fix test to match reverting get_generator_return_type() to default to Any. --- test-data/unit/check-statements.test | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test-data/unit/check-statements.test b/test-data/unit/check-statements.test index 0be318fd0c2c..2812eeaadd5f 100644 --- a/test-data/unit/check-statements.test +++ b/test-data/unit/check-statements.test @@ -81,14 +81,12 @@ def f() -> Generator[int, None, None]: [out] main: note: In function "f": -[case testReturnInIteratorError] +[case testReturnInIterator] from typing import Iterator def f() -> Iterator[int]: yield 1 return "foo" [out] -main: note: In function "f": -main:4: error: No return value expected -- If statement -- ------------ From 6581ce7d69408f9dd6dfa9abf7e0d6e8ae41a144 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 26 Jul 2016 17:01:32 -0700 Subject: [PATCH 28/28] Rename get_awaitable_return_type() to check_awaitable_expr(), update docstring. --- mypy/checker.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 26d822c22d79..df622c47a4bd 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -360,7 +360,7 @@ def get_generator_return_type(self, return_type: Type, is_coroutine: bool) -> Ty return AnyType() elif return_type.type.fullname() == 'typing.Generator': # Generator is one of the two types which specify the type of values it returns into - # `yield from` expressions (using a `return` statements). + # `yield from` expressions (using a `return` statement). if len(return_type.args) == 3: return return_type.args[2] else: @@ -377,8 +377,11 @@ def get_generator_return_type(self, return_type: Type, is_coroutine: bool) -> Ty # type when used in a `yield from` expression. return AnyType() - def get_awaitable_return_type(self, t: Type, ctx: Context, msg: str) -> Type: - """Given a type t, verify that it is a subtype of Awaitable[tr] and return tr.""" + def check_awaitable_expr(self, t: Type, ctx: Context, msg: str) -> Type: + """Check the argument to `await` and extract the type of value. + + Also used by `async for` and `async with`. + """ if not self.check_subtype(t, self.named_type('typing.Awaitable'), ctx, msg, 'actual type', 'expected type'): return AnyType() @@ -1700,8 +1703,8 @@ def analyze_async_iterable_item_type(self, expr: Node) -> Type: iterator = echk.check_call(method, [], [], expr)[0] method = echk.analyze_external_member_access('__anext__', iterator, expr) awaitable = echk.check_call(method, [], [], expr)[0] - return self.get_awaitable_return_type(awaitable, expr, - messages.INCOMPATIBLE_TYPES_IN_ASYNC_FOR) + return self.check_awaitable_expr(awaitable, expr, + messages.INCOMPATIBLE_TYPES_IN_ASYNC_FOR) def analyze_iterable_item_type(self, expr: Node) -> Type: """Analyse iterable expression and return iterator item type.""" @@ -1821,14 +1824,14 @@ def check_async_with_item(self, expr: Expression, target: Expression) -> None: ctx = self.accept(expr) enter = echk.analyze_external_member_access('__aenter__', ctx, expr) obj = echk.check_call(enter, [], [], expr)[0] - obj = self.get_awaitable_return_type( + obj = self.check_awaitable_expr( obj, expr, messages.INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AENTER) if target: self.check_assignment(target, self.temp_node(obj, expr)) exit = echk.analyze_external_member_access('__aexit__', ctx, expr) arg = self.temp_node(AnyType(), expr) res = echk.check_call(exit, [arg] * 3, [nodes.ARG_POS] * 3, expr)[0] - self.get_awaitable_return_type( + self.check_awaitable_expr( res, expr, messages.INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AEXIT) def check_with_item(self, expr: Expression, target: Expression) -> None: @@ -2036,7 +2039,7 @@ def visit_await_expr(self, e: AwaitExpr) -> Type: actual_type = self.accept(e.expr, expected_type) if isinstance(actual_type, AnyType): return AnyType() - return self.get_awaitable_return_type(actual_type, e, messages.INCOMPATIBLE_TYPES_IN_AWAIT) + return self.check_awaitable_expr(actual_type, e, messages.INCOMPATIBLE_TYPES_IN_AWAIT) # # Helpers