This repository was archived by the owner on Apr 14, 2022. It is now read-only.

Send request through trio backend #3

Merged: 5 commits, Jan 2, 2018
12 changes: 12 additions & 0 deletions trio_test.py
@@ -0,0 +1,12 @@
import trio
import urllib3
from urllib3.backends.trio_backend import TrioBackend


async def main():
http = urllib3.PoolManager(TrioBackend())
r = await http.request('GET', 'http://httpbin.org/robots.txt', preload_content=False)
print(r.status) # prints "200"
print(await r.read()) # prints "User-agent: *\nDisallow: /deny\n"

trio.run(main)
1 change: 1 addition & 0 deletions urllib3/backends/_util.py
@@ -15,6 +15,7 @@
else:
DEFAULT_SELECTOR = selectors.PollSelector


def is_readable(sock):
s = DEFAULT_SELECTOR()
s.register(sock, selectors.EVENT_READ)
6 changes: 5 additions & 1 deletion urllib3/backends/sync_backend.py
@@ -1,3 +1,6 @@
+import errno
+import select
import socket
+import ssl
from ..util.connection import create_connection
from ..util.ssl_ import ssl_wrap_socket
@@ -10,6 +13,7 @@

BUFSIZE = 65536


class SyncBackend(object):
def __init__(self, connect_timeout, read_timeout):
self._connect_timeout = connect_timeout
@@ -72,7 +76,7 @@ async def receive_some(self):
else:
raise

-async def send_and_receive_for_a_while(produce_bytes, consume_bytes):
+async def send_and_receive_for_a_while(self, produce_bytes, consume_bytes):
outgoing_finished = False
outgoing = b""
try:
19 changes: 12 additions & 7 deletions urllib3/backends/trio_backend.py
@@ -3,6 +3,9 @@
from . import LoopAbort
from ._util import is_readable

+BUFSIZE = 65536


class TrioBackend:
async def connect(
self, host, port, source_address=None, socket_options=None):
@@ -23,6 +26,8 @@ async def connect(
# cancellation, but we probably should do something to detect when the stream
# has been broken by cancellation (e.g. a timeout) and make is_readable return
# True so the connection won't be reused.


class TrioSocket:
def __init__(self, stream):
self._stream = stream
@@ -40,7 +45,7 @@ def getpeercert(self, binary=False):
async def receive_some(self):
return await self._stream.receive_some(BUFSIZE)

-async def send_and_receive_for_a_while(produce_bytes, consume_bytes):
+async def send_and_receive_for_a_while(self, produce_bytes, consume_bytes):
async def sender():
while True:
outgoing = await produce_bytes()
@@ -50,18 +55,18 @@ async def sender():

async def receiver():
while True:
-incoming = await stream.receive_some(BUFSIZE)
+incoming = await self._stream.receive_some(BUFSIZE)
consume_bytes(incoming)

try:
async with trio.open_nursery() as nursery:
-nursery.spawn(sender)
-nursery.spawn(receiver)
+nursery.start_soon(sender)
+nursery.start_soon(receiver)
except LoopAbort:
pass

-def forceful_close(self):
-    self._stream.forceful_close()
+async def forceful_close(self):
+    await trio.aclose_forcefully(self._stream)
Member:

Ugh, forceful close raises such tricky questions for API design.

I started to write a long thing here about why close should maybe be synchronous after all. Trio's abstract stream interface makes it async because it's, y'know, abstract, and some streams need it to be async. But for the particular concrete streams that urllib3 uses, it's always synchronous. So maybe, I was thinking, it should be synchronous here too, and we'd make the trio backend jump through the necessary hoops to make that work.

Then I looked at Twisted and asyncio and realized that they also make closing a socket an async operation for weird API reasons, so fine, let's just make it async :-). It's an issue we might want to revisit as this develops, though. (I'm also inclined to say that if we're going to make closing async, we should rename it aclose, which is the convention trio uses; I stole it from the interpreter's async generator API. But that's also something we can worry about later.)
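
For illustration, a rough sketch of the two designs being weighed here. This is not code from the PR; it assumes the trio-0.x API this branch targets, where SSLStream lives in trio.ssl and raw trio sockets close synchronously:

import trio

async def forceful_close_async(stream):
    # What this PR adopts: trio's helper closes the stream immediately,
    # skipping any graceful TLS shutdown.
    await trio.aclose_forcefully(stream)

def forceful_close_sync(stream):
    # The hypothetical synchronous alternative: unwrap an SSL stream to
    # its transport stream (the same trick is_readable uses below), then
    # close the underlying trio socket, which is a synchronous call.
    if isinstance(stream, trio.ssl.SSLStream):
        stream = stream.transport_stream
    stream.socket.close()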

Member Author:

Thanks for the explanation. So forceful_close should become aclose_forcefully?

Member:

Oh, I guess it could. That doesn't matter too much though, because the backend API is internal anyway. The more important question is whether we want to rename Response.close, PoolManager.close, etc.
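
To make that concrete, here is what the demo at the top of this PR would look like if the public close methods went async too. The await http.close() line and any aclose renaming are hypothetical, not part of this PR:

import trio
import urllib3
from urllib3.backends.trio_backend import TrioBackend

async def main():
    http = urllib3.PoolManager(TrioBackend())
    r = await http.request('GET', 'http://httpbin.org/robots.txt',
                           preload_content=False)
    print(await r.read())
    # Hypothetical: an async public close means every caller awaits it,
    # possibly under the aclose name per trio's convention.
    await http.close()

trio.run(main)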

Member Author:

Okay, I'll prepare another PR for this change.

Member Author:

Just opened #4.


def is_readable(self):
# This is a bit of a hack, but I can't think of a better API that trio
@@ -73,6 +78,6 @@ def is_readable(self):
sock_stream = sock_stream.transport_stream
sock = sock_stream.socket
return is_readable(sock)

def set_readable_watch_state(self, enabled):
pass
17 changes: 13 additions & 4 deletions urllib3/backends/twisted_backend.py
@@ -1,6 +1,7 @@
import socket
import OpenSSL.crypto
-from twisted.internet import protocol
+from twisted.internet import protocol, ssl
+from twisted.internet.interfaces import IHandshakeListener
from twisted.internet.endpoints import HostnameEndpoint, connectProtocol
from twisted.internet.defer import (
Deferred, DeferredList, CancelledError, ensureDeferred,
@@ -11,6 +12,8 @@
from . import LoopAbort

# XX need to add timeout support, esp. on connect


class TwistedBackend:
def __init__(self, reactor):
self._reactor = reactor
@@ -21,7 +24,7 @@ async def connect(self, host, port, source_address=None, socket_options=None):
raise NotImplementedError(
"twisted backend doesn't support setting source_address")

-factory = protocol.Factory.forProtocol(TwistedSocketProtocol)
+# factory = protocol.Factory.forProtocol(TwistedSocketProtocol)
endpoint = HostnameEndpoint(self._reactor, host, port)
d = connectProtocol(endpoint, TwistedSocketProtocol())
# XX d.addTimeout(...)
@@ -39,8 +42,12 @@ async def connect(self, host, port, source_address=None, socket_options=None):
# enums
class _DATA_RECEIVED:
pass


class _RESUME_PRODUCING:
pass


class _HANDSHAKE_COMPLETED:
pass

@@ -71,6 +78,7 @@ async def _wait_for(self, event):
d = Deferred()
# We might get callbacked, we might get cancelled; either way we want
# to clean up then pass through the result:

def cleanup(obj):
del self._events[event]
return obj
@@ -138,6 +146,7 @@ def set_readable_watch_state(self, enabled):
else:
self.transport.pauseProducing()


class DoubleError(Exception):
def __init__(self, exc1, exc2):
self.exc1 = exc1
@@ -173,7 +182,7 @@ def getpeercert(self, binary=False):
async def receive_some(self):
return await self._protocol.receive_some()

-async def send_and_receive_for_a_while(produce_bytes, consume_bytes):
+async def send_and_receive_for_a_while(self, produce_bytes, consume_bytes):
async def sender():
while True:
outgoing = await produce_bytes()
@@ -209,7 +218,7 @@ def receive_loop_allback(result):

# Wait for both to finish, and then figure out if we need to raise an
# exception.
-results = await DeferredList([d1, d2])
+results = await DeferredList([send_loop, receive_loop])
# First, find the failure objects - but since we've almost always
# cancelled one of the deferreds, which causes it to raise
# CancelledError, we can't treat these at face value.
54 changes: 29 additions & 25 deletions urllib3/connectionpool.py
@@ -123,7 +123,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
# Return False to re-raise any potential exceptions
return False

-def close(self):
+async def close(self):
"""
Close all pooled connections and disable the pool.
"""
@@ -248,7 +248,7 @@ def _new_conn(self):
**self.conn_kw)
return conn

-def _get_conn(self, timeout=None):
+async def _get_conn(self, timeout=None):
"""
Get a connection. Will return a pooled connection if one is available.

@@ -277,11 +277,11 @@ def _get_conn(self, timeout=None):
# If this is a persistent connection, check if it got disconnected
if conn and is_connection_dropped(conn):
log.debug("Resetting dropped connection: %s", self.host)
-conn.close()
+await conn.close()

return conn or self._new_conn()

-def _put_conn(self, conn):
+async def _put_conn(self, conn):
"""
Put a connection back into the pool.

@@ -309,13 +309,13 @@ def _put_conn(self, conn):

# Connection never got put back into the pool, close it.
if conn:
-conn.close()
+await conn.close()

-def _start_conn(self, conn, connect_timeout):
+async def _start_conn(self, conn, connect_timeout):
"""
Called right before a request is made, after the socket is created.
"""
-conn.connect(connect_timeout=connect_timeout)
+await conn.connect(connect_timeout=connect_timeout)

def _get_timeout(self, timeout):
""" Helper that always returns a :class:`urllib3.util.Timeout` """
@@ -347,8 +347,8 @@ def _raise_timeout(self, err, url, timeout_value):
if 'timed out' in str(err) or 'did not complete (read)' in str(err): # Python 2.6
raise ReadTimeoutError(self, url, "Read timed out. (read timeout=%s)" % timeout_value)

-def _make_request(self, conn, method, url, timeout=_Default, body=None,
-                  headers=None):
+async def _make_request(
+        self, conn, method, url, timeout=_Default, body=None, headers=None):
"""
Perform a request on a given urllib connection object taken from our
pool.
@@ -370,7 +370,7 @@ def _make_request(self, conn, method, url, timeout=_Default, body=None,

# Trigger any extra validation we need to do.
try:
-self._start_conn(conn, timeout_obj.connect_timeout)
+await self._start_conn(conn, timeout_obj.connect_timeout)
except (SocketTimeout, BaseSSLError) as e:
# Py2 raises this as a BaseSSLError, Py3 raises it as socket timeout.
self._raise_timeout(err=e, url=url, timeout_value=conn.timeout)
@@ -405,7 +405,8 @@ def _make_request(self, conn, method, url, timeout=_Default, body=None,

# Receive the response from the server
try:
-response = conn.send_request(request, read_timeout=read_timeout)
+response = await conn.send_request(
+    request, read_timeout=read_timeout)
except (SocketTimeout, BaseSSLError, SocketError) as e:
self._raise_timeout(err=e, url=url, timeout_value=read_timeout)
raise
@@ -420,7 +421,7 @@ def _make_request(self, conn, method, url, timeout=_Default, body=None,
def _absolute_url(self, path):
return Url(scheme=self.scheme, host=self.host, port=self.port, path=path).url

-def close(self):
+async def close(self):
"""
Close all pooled connections and disable the pool.
"""
@@ -431,7 +432,7 @@ def close(self):
while True:
conn = old_pool.get(block=False)
if conn:
-conn.close()
+await conn.close()

except queue.Empty:
pass # Done.
@@ -457,8 +458,9 @@ def is_same_host(self, url):

return (scheme, host, port) == (self.scheme, self.host, self.port)

-def urlopen(self, method, url, body=None, headers=None, retries=None,
-            timeout=_Default, pool_timeout=None, body_pos=None, **response_kw):
+async def urlopen(self, method, url, body=None, headers=None, retries=None,
+                  timeout=_Default, pool_timeout=None, body_pos=None,
+                  **response_kw):
"""
Get a connection from the pool and perform an HTTP request. This is the
lowest level call for making a request, so you'll need to specify all
@@ -554,14 +556,15 @@ def urlopen(self, method, url, body=None, headers=None, retries=None,
try:
# Request a connection from the queue.
timeout_obj = self._get_timeout(timeout)
-conn = self._get_conn(timeout=pool_timeout)
+conn = await self._get_conn(timeout=pool_timeout)

conn.timeout = timeout_obj.connect_timeout

# Make the request on the base connection object.
-base_response = self._make_request(conn, method, url,
-                                   timeout=timeout_obj,
-                                   body=body, headers=headers)
+base_response = await self._make_request(conn, method, url,
+                                         timeout=timeout_obj,
+                                         body=body,
+                                         headers=headers)

# Pass method to Response for length checking
response_kw['request_method'] = method
@@ -615,22 +618,23 @@ def urlopen(self, method, url, body=None, headers=None, retries=None,
# to throw the connection away unless explicitly told not to.
# Close the connection, set the variable to None, and make sure
# we put the None back in the pool to avoid leaking it.
-conn = conn and conn.close()
+conn = conn and await conn.close()
release_this_conn = True

if release_this_conn:
# Put the connection back to be reused. If the connection is
# expired then it will be None, which will get replaced with a
# fresh connection during _get_conn.
-self._put_conn(conn)
+await self._put_conn(conn)

if not conn:
# Try again
log.warning("Retrying (%r) after connection "
"broken by '%r': %s", retries, err, url)
-return self.urlopen(method, url, body, headers, retries,
-                    timeout=timeout, pool_timeout=pool_timeout,
-                    body_pos=body_pos, **response_kw)
+return await self.urlopen(method, url, body, headers, retries,
+                          timeout=timeout,
+                          pool_timeout=pool_timeout,
+                          body_pos=body_pos, **response_kw)

# Check if we should retry the HTTP response.
has_retry_after = bool(response.getheader('Retry-After'))
@@ -646,7 +650,7 @@ def urlopen(self, method, url, body=None, headers=None, retries=None,
return response
retries.sleep(response)
log.debug("Retry: %s", url)
-return self.urlopen(
+return await self.urlopen(
method, url, body, headers,
retries=retries, timeout=timeout, pool_timeout=pool_timeout,
body_pos=body_pos, **response_kw)
11 changes: 6 additions & 5 deletions urllib3/poolmanager.py
@@ -140,7 +140,7 @@ class PoolManager(RequestMethods):

proxy = None

-def __init__(self, num_pools=10, headers=None, **connection_pool_kw):
+def __init__(self, backend=None, num_pools=10, headers=None, **connection_pool_kw):
RequestMethods.__init__(self, headers)
self.connection_pool_kw = connection_pool_kw
self.pools = RecentlyUsedContainer(num_pools,
@@ -150,6 +150,7 @@ def __init__(self, num_pools=10, headers=None, **connection_pool_kw):
# override them.
self.pool_classes_by_scheme = pool_classes_by_scheme
self.key_fn_by_scheme = key_fn_by_scheme.copy()
+self.backend = backend

def __enter__(self):
return self
@@ -184,7 +185,7 @@ def _new_pool(self, scheme, host, port, request_context=None):
for kw in SSL_KEYWORDS:
request_context.pop(kw, None)

-return pool_cls(host, port, **request_context)
+return pool_cls(host, port, **request_context, backend=self.backend)

def clear(self):
"""
@@ -290,7 +291,7 @@ def _merge_pool_kwargs(self, override):
base_pool_kwargs[key] = value
return base_pool_kwargs

-def urlopen(self, method, url, redirect=True, **kw):
+async def urlopen(self, method, url, redirect=True, **kw):
"""
Same as :meth:`urllib3.connectionpool.HTTPConnectionPool.urlopen`
with redirect logic and only sends the request-uri portion of the
@@ -312,9 +313,9 @@ def urlopen(self, method, url, redirect=True, **kw):
kw['headers'] = self.headers

if self.proxy is not None and u.scheme == "http":
-response = conn.urlopen(method, url, **kw)
+response = await conn.urlopen(method, url, **kw)
else:
-response = conn.urlopen(method, u.request_uri, **kw)
+response = await conn.urlopen(method, u.request_uri, **kw)

redirect_location = redirect and response.get_redirect_location()
if not redirect_location: