diff --git a/misc/async_matrix.py b/misc/async_matrix.py new file mode 100644 index 000000000000..e9a758a229dc --- /dev/null +++ b/misc/async_matrix.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +"""Test various combinations of generators/coroutines. + +This was used to cross-check the errors in the test case +testFullCoroutineMatrix in test-data/unit/check-async-await.test. +""" + +import sys +from types import coroutine +from typing import Any, AsyncIterator, Awaitable, Generator, Iterator + +# The various things you might try to use in `await` or `yield from`. + +def plain_generator() -> Generator[str, None, int]: + yield 'a' + return 1 + +async def plain_coroutine() -> int: + return 1 + +@coroutine +def decorated_generator() -> Generator[str, None, int]: + yield 'a' + return 1 + +@coroutine +async def decorated_coroutine() -> int: + return 1 + +class It(Iterator[str]): + stop = False + def __iter__(self) -> 'It': + return self + def __next__(self) -> str: + if self.stop: + raise StopIteration('end') + else: + self.stop = True + return 'a' + +def other_iterator() -> It: + return It() + +class Aw(Awaitable[int]): + def __await__(self) -> Generator[str, Any, int]: + yield 'a' + return 1 + +def other_coroutine() -> Aw: + return Aw() + +# The various contexts in which `await` or `yield from` might occur. + +def plain_host_generator(func) -> Generator[str, None, None]: + yield 'a' + x = 0 + f = func() + try: + x = yield from f + finally: + try: + f.close() + except AttributeError: + pass + +async def plain_host_coroutine(func) -> None: + x = 0 + x = await func() + +@coroutine +def decorated_host_generator(func) -> Generator[str, None, None]: + yield 'a' + x = 0 + f = func() + try: + x = yield from f + finally: + try: + f.close() + except AttributeError: + pass + +@coroutine +async def decorated_host_coroutine(func) -> None: + x = 0 + x = await func() + +# Main driver. + +def main(): + verbose = ('-v' in sys.argv) + for host in [plain_host_generator, plain_host_coroutine, + decorated_host_generator, decorated_host_coroutine]: + print() + print("==== Host:", host.__name__) + for func in [plain_generator, plain_coroutine, + decorated_generator, decorated_coroutine, + other_iterator, other_coroutine]: + print(" ---- Func:", func.__name__) + try: + f = host(func) + for i in range(10): + try: + x = f.send(None) + if verbose: + print(" yield:", x) + except StopIteration as e: + if verbose: + print(" stop:", e.value) + break + else: + if verbose: + print(" ???? still going") + except Exception as e: + print(" error:", repr(e)) + +# Run main(). + +if __name__ == '__main__': + main() diff --git a/mypy/checker.py b/mypy/checker.py index d959f73b8212..c8db99f596d3 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -265,54 +265,67 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: # in PEP 492 and only available in Python >= 3.5. # # Classic generators can be parameterized with three types: - # - ty is the yield type (the type of y in `yield y`) - # - ts is the type received by yield (the type of s in `s = yield`) - # (it's named `ts` after `send()`, since `tr` is `return`). - # - tr is the return type (the type of r in `return r`) + # - ty is the Yield type (the type of y in `yield y`) + # - tc is the type reCeived by yield (the type of c in `c = yield`). + # - tr is the Return type (the type of r in `return r`) # # A classic generator must define a return type that's either - # `Generator[ty, ts, tr]`, Iterator[ty], or Iterable[ty] (or - # object or Any). 
If ts/tr are not given, both are Void. + # `Generator[ty, tc, tr]`, Iterator[ty], or Iterable[ty] (or + # object or Any). If tc/tr are not given, both are Void. # # A coroutine must define a return type corresponding to tr; the # other two are unconstrained. The "external" return type (seen # by the caller) is Awaitable[tr]. # + # In addition, there's the synthetic type AwaitableGenerator: it + # inherits from both Awaitable and Generator and can be used both + # in `yield from` and in `await`. This type is set automatically + # for functions decorated with `@types.coroutine` or + # `@asyncio.coroutine`. Its single parameter corresponds to tr. + # # There are several useful methods, each taking a type t and a # flag c indicating whether it's for a generator or coroutine: # # - is_generator_return_type(t, c) returns whether t is a Generator, - # Iterator, Iterable (if not c), or Awaitable (if c). + # Iterator, Iterable (if not c), or Awaitable (if c), or + # AwaitableGenerator (regardless of c). # - get_generator_yield_type(t, c) returns ty. - # - get_generator_receive_type(t, c) returns ts. + # - get_generator_receive_type(t, c) returns tc. # - get_generator_return_type(t, c) returns tr. def is_generator_return_type(self, typ: Type, is_coroutine: bool) -> bool: """Is `typ` a valid type for a generator/coroutine? - True if either Generator or Awaitable is a supertype of `typ`. + True if `typ` is a *supertype* of Generator or Awaitable. + Also true it it's *exactly* AwaitableGenerator (modulo type parameters). """ if is_coroutine: + # This means we're in Python 3.5 or later. at = self.named_generic_type('typing.Awaitable', [AnyType()]) - return is_subtype(at, typ) + if is_subtype(at, typ): + return True else: gt = self.named_generic_type('typing.Generator', [AnyType(), AnyType(), AnyType()]) - return is_subtype(gt, typ) + if is_subtype(gt, typ): + return True + return isinstance(typ, Instance) and typ.type.fullname() == 'typing.AwaitableGenerator' def get_generator_yield_type(self, return_type: Type, is_coroutine: bool) -> Type: """Given the declared return type of a generator (t), return the type it yields (ty).""" if isinstance(return_type, AnyType): return AnyType() elif not self.is_generator_return_type(return_type, is_coroutine): - # If the function doesn't have a proper Generator (or superclass) return type, anything - # is permissible. + # If the function doesn't have a proper Generator (or + # Awaitable) return type, anything is permissible. return AnyType() elif not isinstance(return_type, Instance): # Same as above, but written as a separate branch so the typechecker can understand. return AnyType() elif return_type.type.fullname() == 'typing.Awaitable': + # Awaitable: ty is Any. return AnyType() elif return_type.args: + # AwaitableGenerator, Generator, Iterator, or Iterable; ty is args[0]. ret_type = return_type.args[0] # TODO not best fix, better have dedicated yield token if isinstance(ret_type, NoneTyp): @@ -324,33 +337,31 @@ def get_generator_yield_type(self, return_type: Type, is_coroutine: bool) -> Typ else: # If the function's declared supertype of Generator has no type # parameters (i.e. is `object`), then the yielded values can't - # be accessed so any type is acceptable. + # be accessed so any type is acceptable. IOW, ty is Any. 
+ # (However, see https://github.com/python/mypy/issues/1933) return AnyType() def get_generator_receive_type(self, return_type: Type, is_coroutine: bool) -> Type: - """Given a declared generator return type (t), return the type its yield receives (ts).""" + """Given a declared generator return type (t), return the type its yield receives (tc).""" if isinstance(return_type, AnyType): return AnyType() elif not self.is_generator_return_type(return_type, is_coroutine): - # If the function doesn't have a proper Generator (or superclass) return type, anything - # is permissible. + # If the function doesn't have a proper Generator (or + # Awaitable) return type, anything is permissible. return AnyType() elif not isinstance(return_type, Instance): # Same as above, but written as a separate branch so the typechecker can understand. return AnyType() - elif return_type.type.fullname() == 'typing.Generator': - # Generator is one of the two types which specify the type of values it can receive. - if len(return_type.args) == 3: - return return_type.args[1] - else: - return AnyType() elif return_type.type.fullname() == 'typing.Awaitable': - # Awaitable is one of the two types which specify the type of values it can receive. - # According to the stub this is always `Any`. + # Awaitable, AwaitableGenerator: tc is Any. return AnyType() + elif (return_type.type.fullname() in ('typing.Generator', 'typing.AwaitableGenerator') + and len(return_type.args) >= 3): + # Generator: tc is args[1]. + return return_type.args[1] else: # `return_type` is a supertype of Generator, so callers won't be able to send it - # values. + # values. IOW, tc is None. if experiments.STRICT_OPTIONAL: return NoneTyp(is_ret_type=True) else: @@ -361,29 +372,21 @@ def get_generator_return_type(self, return_type: Type, is_coroutine: bool) -> Ty if isinstance(return_type, AnyType): return AnyType() elif not self.is_generator_return_type(return_type, is_coroutine): - # If the function doesn't have a proper Generator (or superclass) return type, anything - # is permissible. + # If the function doesn't have a proper Generator (or + # Awaitable) return type, anything is permissible. return AnyType() elif not isinstance(return_type, Instance): # Same as above, but written as a separate branch so the typechecker can understand. return AnyType() - elif return_type.type.fullname() == 'typing.Generator': - # Generator is one of the two types which specify the type of values it returns into - # `yield from` expressions (using a `return` statement). - if len(return_type.args) == 3: - return return_type.args[2] - else: - return AnyType() - elif return_type.type.fullname() == 'typing.Awaitable': - # Awaitable is the other type which specifies the type of values it returns into - # `yield from` expressions (using `return`). - if len(return_type.args) == 1: - return return_type.args[0] - else: - return AnyType() + elif return_type.type.fullname() == 'typing.Awaitable' and len(return_type.args) == 1: + # Awaitable: tr is args[0]. + return return_type.args[0] + elif (return_type.type.fullname() in ('typing.Generator', 'typing.AwaitableGenerator') + and len(return_type.args) >= 3): + # AwaitableGenerator, Generator: tr is args[2]. + return return_type.args[2] else: - # `return_type` is supertype of Generator, so callers won't be able to see the return - # type when used in a `yield from` expression. + # Supertype of Generator (Iterator, Iterable, object): tr is any. 
return AnyType() def check_awaitable_expr(self, t: Type, ctx: Context, msg: str) -> Type: @@ -540,6 +543,20 @@ def is_implicit_any(t: Type) -> bool: if not isinstance(typ.ret_type.args[2], (Void, NoneTyp, AnyType)): self.fail(messages.INVALID_GENERATOR_RETURN_ITEM_TYPE, typ) + # Fix the type if decorated with `@types.coroutine` or `@asyncio.coroutine`. + if defn.is_awaitable_coroutine: + # Update the return type to AwaitableGenerator. + # (This doesn't exist in typing.py, only in typing.pyi.) + t = typ.ret_type + c = defn.is_coroutine + ty = self.get_generator_yield_type(t, c) + tc = self.get_generator_receive_type(t, c) + tr = self.get_generator_return_type(t, c) + ret_type = self.named_generic_type('typing.AwaitableGenerator', + [ty, tc, tr, t]) + typ = typ.copy_modified(ret_type=ret_type) + defn.type = typ + # Push return type. self.return_types.append(typ.ret_type) @@ -1872,6 +1889,11 @@ def visit_call_expr(self, e: CallExpr) -> Type: return self.expr_checker.visit_call_expr(e) def visit_yield_from_expr(self, e: YieldFromExpr) -> Type: + # NOTE: Whether `yield from` accepts an `async def` decorated + # with `@types.coroutine` (or `@asyncio.coroutine`) depends on + # whether the generator containing the `yield from` is itself + # thus decorated. But it accepts a generator regardless of + # how it's decorated. return_type = self.return_types[-1] subexpr_type = self.accept(e.expr, return_type) iter_type = None # type: Type @@ -1882,6 +1904,8 @@ def visit_yield_from_expr(self, e: YieldFromExpr) -> Type: iter_type = AnyType() elif (isinstance(subexpr_type, Instance) and is_subtype(subexpr_type, self.named_type('typing.Iterable'))): + if self.is_async_def(subexpr_type) and not self.has_coroutine_decorator(return_type): + self.msg.yield_from_invalid_operand_type(subexpr_type, e) iter_method_type = self.expr_checker.analyze_external_member_access( '__iter__', subexpr_type, @@ -1892,8 +1916,12 @@ def visit_yield_from_expr(self, e: YieldFromExpr) -> Type: iter_type, _ = self.expr_checker.check_call(iter_method_type, [], [], context=generic_generator_type) else: - self.msg.yield_from_invalid_operand_type(subexpr_type, e) - iter_type = AnyType() + if not (self.is_async_def(subexpr_type) and self.has_coroutine_decorator(return_type)): + self.msg.yield_from_invalid_operand_type(subexpr_type, e) + iter_type = AnyType() + else: + iter_type = self.check_awaitable_expr(subexpr_type, e, + messages.INCOMPATIBLE_TYPES_IN_YIELD_FROM) # Check that the iterator's item type matches the type yielded by the Generator function # containing this `yield from` expression. @@ -1919,6 +1947,30 @@ def visit_yield_from_expr(self, e: YieldFromExpr) -> Type: else: return Void() + def has_coroutine_decorator(self, t: Type) -> bool: + """Whether t came from a function decorated with `@coroutine`.""" + return isinstance(t, Instance) and t.type.fullname() == 'typing.AwaitableGenerator' + + def is_async_def(self, t: Type) -> bool: + """Whether t came from a function defined using `async def`.""" + # In check_func_def(), when we see a function decorated with + # `@typing.coroutine` or `@async.coroutine`, we change the + # return type to typing.AwaitableGenerator[...], so that its + # type is compatible with either Generator or Awaitable. + # But for the check here we need to know whether the original + # function (before decoration) was an `async def`. 
The + # AwaitableGenerator type conveniently preserves the original + # type as its 4th parameter (3rd when using 0-origin indexing + # :-), so that we can recover that information here. + # (We really need to see whether the original, undecorated + # function was an `async def`, which is orthogonal to its + # decorations.) + if (isinstance(t, Instance) + and t.type.fullname() == 'typing.AwaitableGenerator' + and len(t.args) >= 4): + t = t.args[3] + return isinstance(t, Instance) and t.type.fullname() == 'typing.Awaitable' + def visit_member_expr(self, e: MemberExpr) -> Type: return self.expr_checker.visit_member_expr(e) diff --git a/mypy/nodes.py b/mypy/nodes.py index 21308e8d798a..bc25b8441045 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -417,6 +417,7 @@ class FuncItem(FuncBase): is_overload = False is_generator = False # Contains a yield statement? is_coroutine = False # Defined using 'async def' syntax? + is_awaitable_coroutine = False # Decorated with '@{typing,asyncio}.coroutine'? is_static = False # Uses @staticmethod? is_class = False # Uses @classmethod? # Variants of function with type variables with values expanded diff --git a/mypy/semanal.py b/mypy/semanal.py index 841d34739e5f..fa5bced17dbd 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -1716,8 +1716,10 @@ def visit_decorator(self, dec: Decorator) -> None: removed.append(i) dec.func.is_abstract = True self.check_decorated_function_is_method('abstractmethod', dec) - elif refers_to_fullname(d, 'asyncio.tasks.coroutine'): + elif (refers_to_fullname(d, 'asyncio.coroutines.coroutine') or + refers_to_fullname(d, 'types.coroutine')): removed.append(i) + dec.func.is_awaitable_coroutine = True elif refers_to_fullname(d, 'builtins.staticmethod'): removed.append(i) dec.func.is_static = True diff --git a/runtests.py b/runtests.py index 17a7384f24dc..47db9407798c 100755 --- a/runtests.py +++ b/runtests.py @@ -274,7 +274,7 @@ def add_samples(driver: Driver) -> None: 'import mypy.codec.register, %s' % bf, cwd=cwd) else: - driver.add_mypy('file %s' % f, f) + driver.add_mypy('file %s' % f, f, '--fast-parser') def usage(status: int) -> None: diff --git a/test-data/samples/crawl.py b/test-data/samples/crawl.py index c7e8ab549746..fe15447ea36c 100644 --- a/test-data/samples/crawl.py +++ b/test-data/samples/crawl.py @@ -20,10 +20,9 @@ ARGS = argparse.ArgumentParser(description="Web crawler") -if sys.platform == 'win32': - ARGS.add_argument( - '--iocp', action='store_true', dest='iocp', - default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') ARGS.add_argument( '--select', action='store_true', dest='select', default=False, help='Use Select event loop instead of default') @@ -592,11 +591,12 @@ def report(self, stats: 'Stats', file: IO[str] = None) -> None: stats.add('html') size = len(self.body or b'') stats.add('html_bytes', size) - print(self.url, self.response.status, - self.ctype, self.encoding, - size, - '%d/%d' % (len(self.new_urls or ()), len(self.urls or ())), - file=file) + if self.log.level: + print(self.url, self.response.status, + self.ctype, self.encoding, + size, + '%d/%d' % (len(self.new_urls or ()), len(self.urls or ())), + file=file) elif self.response is None: print(self.url, 'no response object') else: diff --git a/test-data/samples/crawl2.py b/test-data/samples/crawl2.py new file mode 100644 index 000000000000..a9f28a474ff6 --- /dev/null +++ b/test-data/samples/crawl2.py @@ -0,0 +1,852 @@ 
+#!/usr/bin/env python3.4 + +"""A simple web crawler.""" + +# This is cloned from /examples/crawl.py, +# with type annotations added (PEP 484). +# +# This version (crawl2.) has also been converted to use `async def` + +# `await` (PEP 492). + +import argparse +import asyncio +import cgi +from http.client import BadStatusLine +import logging +import re +import sys +import time +import urllib.parse +from typing import Any, Awaitable, IO, Optional, Sequence, Set, Tuple + + +ARGS = argparse.ArgumentParser(description="Web crawler") +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--select', action='store_true', dest='select', + default=False, help='Use Select event loop instead of default') +ARGS.add_argument( + 'roots', nargs='*', + default=[], help='Root URL (may be repeated)') +ARGS.add_argument( + '--max_redirect', action='store', type=int, metavar='N', + default=10, help='Limit redirection chains (for 301, 302 etc.)') +ARGS.add_argument( + '--max_tries', action='store', type=int, metavar='N', + default=4, help='Limit retries on network errors') +ARGS.add_argument( + '--max_tasks', action='store', type=int, metavar='N', + default=100, help='Limit concurrent connections') +ARGS.add_argument( + '--max_pool', action='store', type=int, metavar='N', + default=100, help='Limit connection pool size') +ARGS.add_argument( + '--exclude', action='store', metavar='REGEX', + help='Exclude matching URLs') +ARGS.add_argument( + '--strict', action='store_true', + default=True, help='Strict host matching (default)') +ARGS.add_argument( + '--lenient', action='store_false', dest='strict', + default=False, help='Lenient host matching') +ARGS.add_argument( + '-v', '--verbose', action='count', dest='level', + default=1, help='Verbose logging (repeat for more verbose)') +ARGS.add_argument( + '-q', '--quiet', action='store_const', const=0, dest='level', + default=1, help='Quiet logging (opposite of --verbose)') + + +ESCAPES = [('quot', '"'), + ('gt', '>'), + ('lt', '<'), + ('amp', '&') # Must be last. + ] + + +def unescape(url: str) -> str: + """Turn & into &, and so on. + + This is the inverse of cgi.escape(). + """ + for name, char in ESCAPES: + url = url.replace('&' + name + ';', char) + return url + + +def fix_url(url: str) -> str: + """Prefix a schema-less URL with http://.""" + if '://' not in url: + url = 'http://' + url + return url + + +class Logger: + + def __init__(self, level: int) -> None: + self.level = level + + def _log(self, n: int, args: Sequence[Any]) -> None: + if self.level >= n: + print(*args, file=sys.stderr, flush=True) + + def log(self, n: int, *args: Any) -> None: + self._log(n, args) + + def __call__(self, n: int, *args: Any) -> None: + self._log(n, args) + + +KeyTuple = Tuple[str, int, bool] + + +class ConnectionPool: + """A connection pool. + + To open a connection, use reserve(). To recycle it, use unreserve(). + + The pool is mostly just a mapping from (host, port, ssl) tuples to + lists of Connections. The currently active connections are *not* + in the data structure; get_connection() takes the connection out, + and recycle_connection() puts it back in. To recycle a + connection, call conn.close(recycle=True). + + There are limits to both the overall pool and the per-key pool. + """ + + def __init__(self, log: Logger, max_pool: int = 10, max_tasks: int = 5) -> None: + self.log = log + self.max_pool = max_pool # Overall limit. + self.max_tasks = max_tasks # Per-key limit. 
+ self.loop = asyncio.get_event_loop() + self.connections = {} # type: Dict[KeyTuple, List[Connection]] + self.queue = [] # type: List[Connection] + + def close(self) -> None: + """Close all connections available for reuse.""" + for conns in self.connections.values(): + for conn in conns: + conn.close() + self.connections.clear() + self.queue.clear() + + async def get_connection(self, host: str, port: int, ssl: bool) -> 'Connection': + """Create or reuse a connection.""" + port = port or (443 if ssl else 80) + try: + ipaddrs = await self.loop.getaddrinfo(host, port) + except Exception as exc: + self.log(0, 'Exception %r for (%r, %r)' % (exc, host, port)) + raise + self.log(1, '* %s resolves to %s' % + (host, ', '.join(ip[4][0] for ip in ipaddrs))) + + # Look for a reusable connection. + for _1, _2, _3, _4, (h, p, *_5) in ipaddrs: + key = h, p, ssl + conn = None + conns = self.connections.get(key) + while conns: + conn = conns.pop(0) + self.queue.remove(conn) + if not conns: + del self.connections[key] + if conn.stale(): + self.log(1, 'closing stale connection for', key) + conn.close() # Just in case. + else: + self.log(1, '* Reusing pooled connection', key, + 'FD =', conn.fileno()) + return conn + + # Create a new connection. + conn = Connection(self.log, self, host, port, ssl) + await conn.connect() + self.log(1, '* New connection', conn.key, 'FD =', conn.fileno()) + return conn + + def recycle_connection(self, conn: 'Connection') -> None: + """Make a connection available for reuse. + + This also prunes the pool if it exceeds the size limits. + """ + if conn.stale(): + conn.close() + return + + key = conn.key + conns = self.connections.setdefault(key, []) + conns.append(conn) + self.queue.append(conn) + + if len(conns) <= self.max_tasks and len(self.queue) <= self.max_pool: + return + + # Prune the queue. + + # Close stale connections for this key first. + stale = [conn for conn in conns if conn.stale()] + if stale: + for conn in stale: + conns.remove(conn) + self.queue.remove(conn) + self.log(1, 'closing stale connection for', key) + conn.close() + if not conns: + del self.connections[key] + + # Close oldest connection(s) for this key if limit reached. + while len(conns) > self.max_tasks: + conn = conns.pop(0) + self.queue.remove(conn) + self.log(1, 'closing oldest connection for', key) + conn.close() + + if len(self.queue) <= self.max_pool: + return + + # Close overall stale connections. + stale = [conn for conn in self.queue if conn.stale()] + if stale: + for conn in stale: + conns = self.connections.get(conn.key) + conns.remove(conn) + self.queue.remove(conn) + self.log(1, 'closing stale connection for', key) + conn.close() + + # Close oldest overall connection(s) if limit reached. 
+ while len(self.queue) > self.max_pool: + conn = self.queue.pop(0) + conns = self.connections.get(conn.key) + c = conns.pop(0) + assert conn == c, (conn.key, conn, c, conns) + self.log(1, 'closing overall oldest connection for', conn.key) + conn.close() + + +class Connection: + + def __init__(self, log: Logger, pool: ConnectionPool, host: str, port: int, ssl: bool) -> None: + self.log = log + self.pool = pool + self.host = host + self.port = port + self.ssl = ssl + self.reader = None # type: asyncio.StreamReader + self.writer = None # type: asyncio.StreamWriter + self.key = None # type: KeyTuple + + def stale(self) -> bool: + return self.reader is None or self.reader.at_eof() + + def fileno(self) -> Optional[int]: + writer = self.writer + if writer is not None: + transport = writer.transport + if transport is not None: + sock = transport.get_extra_info('socket') + if sock is not None: + return sock.fileno() + return None + + async def connect(self) -> None: + self.reader, self.writer = await asyncio.open_connection( + self.host, self.port, ssl=self.ssl) + peername = self.writer.get_extra_info('peername') + if peername: + self.host, self.port = peername[:2] + else: + self.log(1, 'NO PEERNAME???', self.host, self.port, self.ssl) + self.key = self.host, self.port, self.ssl + + def close(self, recycle: bool = False) -> None: + if recycle and not self.stale(): + self.pool.recycle_connection(self) + else: + self.writer.close() + self.pool = self.reader = self.writer = None + + +class Request: + """HTTP request. + + Use connect() to open a connection; send_request() to send the + request; get_response() to receive the response headers. + """ + + def __init__(self, log: Logger, url: str, pool: ConnectionPool) -> None: + self.log = log + self.url = url + self.pool = pool + self.parts = urllib.parse.urlparse(self.url) + self.scheme = self.parts.scheme + assert self.scheme in ('http', 'https'), repr(url) + self.ssl = self.parts.scheme == 'https' + self.netloc = self.parts.netloc + self.hostname = self.parts.hostname + self.port = self.parts.port or (443 if self.ssl else 80) + self.path = (self.parts.path or '/') + self.query = self.parts.query + if self.query: + self.full_path = '%s?%s' % (self.path, self.query) + else: + self.full_path = self.path + self.http_version = 'HTTP/1.1' + self.method = 'GET' + self.headers = [] # type: List[Tuple[str, str]] + self.conn = None # type: Connection + + async def connect(self) -> None: + """Open a connection to the server.""" + self.log(1, '* Connecting to %s:%s using %s for %s' % + (self.hostname, self.port, + 'ssl' if self.ssl else 'tcp', + self.url)) + self.conn = await self.pool.get_connection(self.hostname, + self.port, self.ssl) + + def close(self, recycle: bool = False) -> None: + """Close the connection, recycle if requested.""" + if self.conn is not None: + if not recycle: + self.log(1, 'closing connection for', self.conn.key) + self.conn.close(recycle) + self.conn = None + + async def putline(self, line: str) -> None: + """Write a line to the connection. + + Used for the request line and headers. + """ + self.log(2, '>', line) + self.conn.writer.write(line.encode('latin-1') + b'\r\n') + + async def send_request(self) -> None: + """Send the request.""" + request_line = '%s %s %s' % (self.method, self.full_path, + self.http_version) + await self.putline(request_line) + # TODO: What if a header is already set? 
+ self.headers.append(('User-Agent', 'asyncio-example-crawl/0.0')) + self.headers.append(('Host', self.netloc)) + self.headers.append(('Accept', '*/*')) + # self.headers.append(('Accept-Encoding', 'gzip')) + for key, value in self.headers: + line = '%s: %s' % (key, value) + await self.putline(line) + await self.putline('') + + async def get_response(self) -> 'Response': + """Receive the response.""" + response = Response(self.log, self.conn.reader) + await response.read_headers() + return response + + +class Response: + """HTTP response. + + Call read_headers() to receive the request headers. Then check + the status attribute and call get_header() to inspect the headers. + Finally call read() to receive the body. + """ + + def __init__(self, log: Logger, reader: asyncio.StreamReader) -> None: + self.log = log + self.reader = reader + self.http_version = None # type: str # 'HTTP/1.1' + self.status = None # type: int # 200 + self.reason = None # type: str # 'Ok' + self.headers = [] # type: List[Tuple[str, str]] # [('Content-Type', 'text/html')] + + async def getline(self) -> str: + """Read one line from the connection.""" + line = (await self.reader.readline()).decode('latin-1').rstrip() + self.log(2, '<', line) + return line + + async def read_headers(self) -> None: + """Read the response status and the request headers.""" + status_line = await self.getline() + status_parts = status_line.split(None, 2) + if len(status_parts) != 3: + self.log(0, 'bad status_line', repr(status_line)) + raise BadStatusLine(status_line) + self.http_version, status, self.reason = status_parts + self.status = int(status) + while True: + header_line = await self.getline() + if not header_line: + break + # TODO: Continuation lines. + key, value = header_line.split(':', 1) + self.headers.append((key, value.strip())) + + def get_redirect_url(self, default: str = '') -> str: + """Inspect the status and return the redirect url if appropriate.""" + if self.status not in (300, 301, 302, 303, 307): + return default + return self.get_header('Location', default) + + def get_header(self, key: str, default: str = '') -> str: + """Get one header value, using a case insensitive header name.""" + key = key.lower() + for k, v in self.headers: + if k.lower() == key: + return v + return default + + async def read(self) -> bytes: + """Read the response body. + + This honors Content-Length and Transfer-Encoding: chunked. + """ + nbytes = None + for key, value in self.headers: + if key.lower() == 'content-length': + nbytes = int(value) + break + if nbytes is None: + if self.get_header('transfer-encoding').lower() == 'chunked': + self.log(2, 'parsing chunked response') + blocks = [] + while True: + size_header = await self.reader.readline() + if not size_header: + self.log(0, 'premature end of chunked response') + break + self.log(3, 'size_header =', repr(size_header)) + parts = size_header.split(b';') + size = int(parts[0], 16) + if size: + self.log(3, 'reading chunk of', size, 'bytes') + block = await self.reader.readexactly(size) + assert len(block) == size, (len(block), size) + blocks.append(block) + crlf = await self.reader.readline() + assert crlf == b'\r\n', repr(crlf) + if not size: + break + body = b''.join(blocks) + self.log(1, 'chunked response had', len(body), + 'bytes in', len(blocks), 'blocks') + else: + self.log(3, 'reading until EOF') + body = await self.reader.read() + # TODO: Should make sure not to recycle the connection + # in this case. 
+ else: + body = await self.reader.readexactly(nbytes) + return body + + +class Fetcher: + """Logic and state for one URL. + + When found in crawler.busy, this represents a URL to be fetched or + in the process of being fetched; when found in crawler.done, this + holds the results from fetching it. + + This is usually associated with a task. This references the + crawler for the connection pool and to add more URLs to its todo + list. + + Call fetch() to do the fetching, then report() to print the results. + """ + + def __init__(self, log: Logger, url: str, crawler: 'Crawler', + max_redirect: int = 10, max_tries: int = 4) -> None: + self.log = log + self.url = url + self.crawler = crawler + # We don't loop resolving redirects here -- we just use this + # to decide whether to add the redirect URL to crawler.todo. + self.max_redirect = max_redirect + # But we do loop to retry on errors a few times. + self.max_tries = max_tries + # Everything we collect from the response goes here. + self.task = None # type: asyncio.Task + self.exceptions = [] # type: List[Exception] + self.tries = 0 + self.request = None # type: Request + self.response = None # type: Response + self.body = None # type: bytes + self.next_url = None # type: str + self.ctype = None # type: str + self.pdict = None # type: Dict[str, str] + self.encoding = None # type: str + self.urls = None # type: Set[str] + self.new_urls = None # type: Set[str] + + async def fetch(self) -> None: + """Attempt to fetch the contents of the URL. + + If successful, and the data is HTML, extract further links and + add them to the crawler. Redirects are also added back there. + """ + while self.tries < self.max_tries: + self.tries += 1 + self.request = None + try: + self.request = Request(self.log, self.url, self.crawler.pool) + await self.request.connect() + await self.request.send_request() + self.response = await self.request.get_response() + self.body = await self.response.read() + h_conn = self.response.get_header('connection').lower() + if h_conn != 'close': + self.request.close(recycle=True) + self.request = None + if self.tries > 1: + self.log(1, 'try', self.tries, 'for', self.url, 'success') + break + except (BadStatusLine, OSError) as exc: + self.exceptions.append(exc) + self.log(1, 'try', self.tries, 'for', self.url, + 'raised', repr(exc)) + # import pdb; pdb.set_trace() + # Don't reuse the connection in this case. + finally: + if self.request is not None: + self.request.close() + else: + # We never broke out of the while loop, i.e. all tries failed. + self.log(0, 'no success for', self.url, + 'in', self.max_tries, 'tries') + return + next_url = self.response.get_redirect_url() + if next_url: + self.next_url = urllib.parse.urljoin(self.url, next_url) + if self.max_redirect > 0: + self.log(1, 'redirect to', self.next_url, 'from', self.url) + self.crawler.add_url(self.next_url, self.max_redirect - 1) + else: + self.log(0, 'redirect limit reached for', self.next_url, + 'from', self.url) + else: + if self.response.status == 200: + self.ctype = self.response.get_header('content-type') + self.pdict = {} + if self.ctype: + self.ctype, self.pdict = cgi.parse_header(self.ctype) + self.encoding = self.pdict.get('charset', 'utf-8') + if self.ctype == 'text/html': + body = self.body.decode(self.encoding, 'replace') + # Replace href with (?:href|src) to follow image links. 
+ self.urls = set(re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', + body)) + if self.urls: + self.log(1, 'got', len(self.urls), + 'distinct urls from', self.url) + self.new_urls = set() + for url in self.urls: + url = unescape(url) + url = urllib.parse.urljoin(self.url, url) + url, frag = urllib.parse.urldefrag(url) + if self.crawler.add_url(url): + self.new_urls.add(url) + + def report(self, stats: 'Stats', file: IO[str] = None) -> None: + """Print a report on the state for this URL. + + Also update the Stats instance. + """ + if self.task is not None: + if not self.task.done(): + stats.add('pending') + print(self.url, 'pending', file=file) + return + elif self.task.cancelled(): + stats.add('cancelled') + print(self.url, 'cancelled', file=file) + return + elif self.task.exception(): + stats.add('exception') + exc = self.task.exception() + stats.add('exception_' + exc.__class__.__name__) + print(self.url, exc, file=file) + return + if len(self.exceptions) == self.tries: + stats.add('fail') + exc = self.exceptions[-1] + stats.add('fail_' + str(exc.__class__.__name__)) + print(self.url, 'error', exc, file=file) + elif self.next_url: + stats.add('redirect') + print(self.url, self.response.status, 'redirect', self.next_url, + file=file) + elif self.ctype == 'text/html': + stats.add('html') + size = len(self.body or b'') + stats.add('html_bytes', size) + if self.log.level: + print(self.url, self.response.status, + self.ctype, self.encoding, + size, + '%d/%d' % (len(self.new_urls or ()), len(self.urls or ())), + file=file) + elif self.response is None: + print(self.url, 'no response object') + else: + size = len(self.body or b'') + if self.response.status == 200: + stats.add('other') + stats.add('other_bytes', size) + else: + stats.add('error') + stats.add('error_bytes', size) + stats.add('status_%s' % self.response.status) + print(self.url, self.response.status, + self.ctype, self.encoding, + size, + file=file) + + +class Stats: + """Record stats of various sorts.""" + + def __init__(self) -> None: + self.stats = {} # type: Dict[str, int] + + def add(self, key: str, count: int = 1) -> None: + self.stats[key] = self.stats.get(key, 0) + count + + def report(self, file: IO[str] = None) -> None: + for key, count in sorted(self.stats.items()): + print('%10d' % count, key, file=file) + + +class Crawler: + """Crawl a set of URLs. + + This manages three disjoint sets of URLs (todo, busy, done). The + data structures actually store dicts -- the values in todo give + the redirect limit, while the values in busy and done are Fetcher + instances. + """ + def __init__(self, log: Logger, + roots: Set[str], exclude: str = None, strict: bool = True, # What to crawl. + max_redirect: int = 10, max_tries: int = 4, # Per-url limits. + max_tasks: int = 10, max_pool: int = 10, # Global limits. 
+ ) -> None: + self.log = log + self.roots = roots + self.exclude = exclude + self.strict = strict + self.max_redirect = max_redirect + self.max_tries = max_tries + self.max_tasks = max_tasks + self.max_pool = max_pool + self.todo = {} # type: Dict[str, int] + self.busy = {} # type: Dict[str, Fetcher] + self.done = {} # type: Dict[str, Fetcher] + self.pool = ConnectionPool(self.log, max_pool, max_tasks) + self.root_domains = set() # type: Set[str] + for root in roots: + host = urllib.parse.urlparse(root).hostname + if not host: + continue + if re.match(r'\A[\d\.]*\Z', host): + self.root_domains.add(host) + else: + host = host.lower() + if self.strict: + self.root_domains.add(host) + if host.startswith('www.'): + self.root_domains.add(host[4:]) + else: + self.root_domains.add('www.' + host) + else: + parts = host.split('.') + if len(parts) > 2: + host = '.'.join(parts[-2:]) + self.root_domains.add(host) + for root in roots: + self.add_url(root) + self.governor = asyncio.Semaphore(max_tasks) + self.termination = asyncio.Condition() + self.t0 = time.time() + self.t1 = None # type: Optional[float] + + def close(self) -> None: + """Close resources (currently only the pool).""" + self.pool.close() + + def host_okay(self, host: str) -> bool: + """Check if a host should be crawled. + + A literal match (after lowercasing) is always good. For hosts + that don't look like IP addresses, some approximate matches + are okay depending on the strict flag. + """ + host = host.lower() + if host in self.root_domains: + return True + if re.match(r'\A[\d\.]*\Z', host): + return False + if self.strict: + return self._host_okay_strictish(host) + else: + return self._host_okay_lenient(host) + + def _host_okay_strictish(self, host: str) -> bool: + """Check if a host should be crawled, strict-ish version. + + This checks for equality modulo an initial 'www.' component. + """ + if host.startswith('www.'): + if host[4:] in self.root_domains: + return True + else: + if 'www.' + host in self.root_domains: + return True + return False + + def _host_okay_lenient(self, host: str) -> bool: + """Check if a host should be crawled, lenient version. + + This compares the last two components of the host. 
+ """ + parts = host.split('.') + if len(parts) > 2: + host = '.'.join(parts[-2:]) + return host in self.root_domains + + def add_url(self, url: str, max_redirect: int = None) -> bool: + """Add a URL to the todo list if not seen before.""" + if self.exclude and re.search(self.exclude, url): + return False + parsed = urllib.parse.urlparse(url) + if parsed.scheme not in ('http', 'https'): + self.log(2, 'skipping non-http scheme in', url) + return False + host = parsed.hostname + if not self.host_okay(host): + self.log(2, 'skipping non-root host in', url) + return False + if max_redirect is None: + max_redirect = self.max_redirect + if url in self.todo or url in self.busy or url in self.done: + return False + self.log(1, 'adding', url, max_redirect) + self.todo[url] = max_redirect + return True + + async def crawl(self) -> None: + """Run the crawler until all finished.""" + with (await self.termination): + while self.todo or self.busy: + if self.todo: + url, max_redirect = self.todo.popitem() + fetcher = Fetcher(self.log, url, + crawler=self, + max_redirect=max_redirect, + max_tries=self.max_tries, + ) + self.busy[url] = fetcher + fetcher.task = asyncio.Task(self.fetch(fetcher)) + else: + await self.termination.wait() + self.t1 = time.time() + + async def fetch(self, fetcher: Fetcher) -> None: + """Call the Fetcher's fetch(), with a limit on concurrency. + + Once this returns, move the fetcher from busy to done. + """ + url = fetcher.url + with (await self.governor): + try: + await fetcher.fetch() # Fetcher gonna fetch. + finally: + # Force GC of the task, so the error is logged. + fetcher.task = None + with (await self.termination): + self.done[url] = fetcher + del self.busy[url] + self.termination.notify() + + def report(self, file: IO[str] = None) -> None: + """Print a report on all completed URLs.""" + if self.t1 is None: + self.t1 = time.time() + dt = self.t1 - self.t0 + if dt and self.max_tasks: + speed = len(self.done) / dt / self.max_tasks + else: + speed = 0 + stats = Stats() + print('*** Report ***', file=file) + try: + show = [] # type: List[Tuple[str, Fetcher]] + show.extend(self.done.items()) + show.extend(self.busy.items()) + show.sort() + for url, fetcher in show: + fetcher.report(stats, file=file) + except KeyboardInterrupt: + print('\nInterrupted', file=file) + print('Finished', len(self.done), + 'urls in %.3f secs' % dt, + '(max_tasks=%d)' % self.max_tasks, + '(%.3f urls/sec/task)' % speed, + file=file) + stats.report(file=file) + print('Todo:', len(self.todo), file=file) + print('Busy:', len(self.busy), file=file) + print('Done:', len(self.done), file=file) + print('Date:', time.ctime(), 'local time', file=file) + + +def main() -> None: + """Main program. + + Parse arguments, set up event loop, run crawler, print report. 
+ """ + args = ARGS.parse_args() + if not args.roots: + print('Use --help for command line help') + return + + log = Logger(args.level) + + if args.iocp: + if sys.platform == 'win32': + from asyncio import ProactorEventLoop + loop = ProactorEventLoop() # type: ignore + asyncio.set_event_loop(loop) + else: + assert False + elif args.select: + loop = asyncio.SelectorEventLoop() # type: ignore + asyncio.set_event_loop(loop) + else: + loop = asyncio.get_event_loop() + + roots = {fix_url(root) for root in args.roots} + + crawler = Crawler(log, + roots, exclude=args.exclude, + strict=args.strict, + max_redirect=args.max_redirect, + max_tries=args.max_tries, + max_tasks=args.max_tasks, + max_pool=args.max_pool, + ) + try: + loop.run_until_complete(crawler.crawl()) # Crawler gonna crawl. + except KeyboardInterrupt: + sys.stderr.flush() + print('\nInterrupted\n') + finally: + crawler.report() + crawler.close() + loop.close() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + main() diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test index 581e35d99957..001d5218acd2 100644 --- a/test-data/unit/check-async-await.test +++ b/test-data/unit/check-async-await.test @@ -293,3 +293,109 @@ async def main() -> None: [builtins fixtures/async_await.py] [out] main: note: In function "main": + +[case testYieldTypeCheckInDecoratedCoroutine] +# options: fast_parser +from typing import Generator +from types import coroutine +@coroutine +def f() -> Generator[int, str, int]: + x = yield 0 + x = yield '' # E: Incompatible types in yield (actual type "str", expected type "int") + reveal_type(x) # E: Revealed type is 'builtins.str' + if x: + return 0 + else: + return '' # E: Incompatible return value type (got "str", expected "int") +[builtins fixtures/async_await.py] +[out] +main: note: In function "f": + + +-- The full matrix of coroutine compatibility +-- ------------------------------------------ + +[case testFullCoroutineMatrix] +# options: fast_parser suppress_error_context +from typing import Any, AsyncIterator, Awaitable, Generator, Iterator +from types import coroutine + +# The various things you might try to use in `await` or `yield from`. + +def plain_generator() -> Generator[str, None, int]: + yield 'a' + return 1 + +async def plain_coroutine() -> int: + return 1 + +@coroutine +def decorated_generator() -> Generator[str, None, int]: + yield 'a' + return 1 + +@coroutine +async def decorated_coroutine() -> int: + return 1 + +class It(Iterator[str]): + def __iter__(self) -> 'It': + return self + def __next__(self) -> str: + return 'a' + +def other_iterator() -> It: + return It() + +class Aw(Awaitable[int]): + def __await__(self) -> Generator[str, Any, int]: + yield 'a' + return 1 + +def other_coroutine() -> Aw: + return Aw() + +# The various contexts in which `await` or `yield from` might occur. + +def plain_host_generator() -> Generator[str, None, None]: + yield 'a' + x = 0 + x = yield from plain_generator() + x = yield from plain_coroutine() # E: "yield from" can't be applied to Awaitable[int] + x = yield from decorated_generator() + x = yield from decorated_coroutine() # E: "yield from" can't be applied to AwaitableGenerator[...] 
+ x = yield from other_iterator() + x = yield from other_coroutine() # E: "yield from" can't be applied to "Aw" + +async def plain_host_coroutine() -> None: + x = 0 + x = await plain_generator() # E: Incompatible types in await (actual type Generator[str, None, int], expected type "Awaitable") + x = await plain_coroutine() + x = await decorated_generator() + x = await decorated_coroutine() + x = await other_iterator() # E: Incompatible types in await (actual type "It", expected type "Awaitable") + x = await other_coroutine() + +@coroutine +def decorated_host_generator() -> Generator[str, None, None]: + yield 'a' + x = 0 + x = yield from plain_generator() + x = yield from plain_coroutine() + x = yield from decorated_generator() + x = yield from decorated_coroutine() + x = yield from other_iterator() + x = yield from other_coroutine() # E: "yield from" can't be applied to "Aw" + +@coroutine +async def decorated_host_coroutine() -> None: + x = 0 + x = await plain_generator() # E: Incompatible types in await (actual type Generator[str, None, int], expected type "Awaitable") + x = await plain_coroutine() + x = await decorated_generator() + x = await decorated_coroutine() + x = await other_iterator() # E: Incompatible types in await (actual type "It", expected type "Awaitable") + x = await other_coroutine() + +[builtins fixtures/async_await.py] +[out] diff --git a/test-data/unit/check-statements.test b/test-data/unit/check-statements.test index 2812eeaadd5f..fa19fcf3f9b3 100644 --- a/test-data/unit/check-statements.test +++ b/test-data/unit/check-statements.test @@ -88,6 +88,7 @@ def f() -> Iterator[int]: return "foo" [out] + -- If statement -- ------------ diff --git a/test-data/unit/lib-stub/types.py b/test-data/unit/lib-stub/types.py new file mode 100644 index 000000000000..aa0a19fc99c2 --- /dev/null +++ b/test-data/unit/lib-stub/types.py @@ -0,0 +1,4 @@ +from typing import TypeVar +T = TypeVar('T') +def coroutine(func: T) -> T: + return func diff --git a/test-data/unit/lib-stub/typing.py b/test-data/unit/lib-stub/typing.py index 3ba9a4398c8a..73abb3d3ebf4 100644 --- a/test-data/unit/lib-stub/typing.py +++ b/test-data/unit/lib-stub/typing.py @@ -26,6 +26,7 @@ T = TypeVar('T') U = TypeVar('U') V = TypeVar('V') +S = TypeVar('S') class Container(Generic[T]): @abstractmethod @@ -61,6 +62,9 @@ class Awaitable(Generic[T]): @abstractmethod def __await__(self) -> Generator[Any, Any, T]: pass +class AwaitableGenerator(Generator[T, U, V], Awaitable[V], Generic[T, U, V, S]): + pass + class AsyncIterable(Generic[T]): @abstractmethod def __aiter__(self) -> 'AsyncIterator[T]': pass diff --git a/typeshed b/typeshed index 53f0ed7e689d..57e48f31383c 160000 --- a/typeshed +++ b/typeshed @@ -1 +1 @@ -Subproject commit 53f0ed7e689d7e59da12c2241bdecde8514333ab +Subproject commit 57e48f31383c47f46bfb2675883a81862a397687
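
For reference, the behavior exercised by misc/async_matrix.py and testFullCoroutineMatrix above can be condensed into a short sketch. This is an illustrative example rather than part of the change itself: the function names are made up, and the expected results simply restate those encoded in the test matrix. The point is that a generator or `async def` decorated with `@types.coroutine` (or `@asyncio.coroutine`) now gets the synthetic return type `typing.AwaitableGenerator`, which is compatible with both `await` and, inside a similarly decorated host generator, `yield from`.

from types import coroutine
from typing import Generator

@coroutine
def decorated_generator() -> Generator[str, None, int]:
    # After this change, mypy rewrites the return type to
    # AwaitableGenerator[str, None, int, Generator[str, None, int]].
    yield 'a'
    return 1

async def plain_coroutine() -> int:
    return 1

async def host_coroutine() -> None:
    # AwaitableGenerator is a subtype of Awaitable, so awaiting the
    # decorated generator is accepted; the result type is tr (here int).
    x = await decorated_generator()

@coroutine
def decorated_host_generator() -> Generator[str, None, None]:
    yield 'a'
    # Accepted only because the host generator is itself decorated;
    # in an undecorated generator the same line is rejected with
    # '"yield from" can't be applied to Awaitable[int]'.
    y = yield from plain_coroutine()

def plain_host_generator() -> Generator[str, None, None]:
    yield 'a'
    # Ordinary generators can still be delegated to from a plain
    # (undecorated) host, whether or not they carry the decorator.
    z = yield from decorated_generator()

Running mypy with --fast-parser over a file like this (matching the runtests.py change above) should produce no errors, while moving the `yield from plain_coroutine()` line into plain_host_generator reproduces the Awaitable[int] error from the matrix.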