Skip to content

Refactor SSL shutdown process #385

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 143 additions & 54 deletions tests/test_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2609,14 +2609,18 @@ async def client(addr):

def test_remote_shutdown_receives_trailing_data(self):
if self.implementation == 'asyncio':
# this is an issue in asyncio
raise unittest.SkipTest()

CHUNK = 1024 * 128
SIZE = 32
CHUNK = 1024 * 16
SIZE = 8
count = 0

sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
client_sslctx = self._create_client_ssl_context()
future = None
filled = threading.Lock()
eof_received = threading.Lock()

def server(sock):
incoming = ssl.MemoryBIO()
Expand Down Expand Up @@ -2647,68 +2651,71 @@ def server(sock):
sslobj.write(b'pong')
sock.send(outgoing.read())

time.sleep(0.2) # wait for the peer to fill its backlog

# send close_notify but don't wait for response
with self.assertRaises(ssl.SSLWantReadError):
sslobj.unwrap()
sock.send(outgoing.read())

# should receive all data
data_len = 0
while True:
try:
chunk = len(sslobj.read(16384))
data_len += chunk
except ssl.SSLWantReadError:
incoming.write(sock.recv(16384))
except ssl.SSLZeroReturnError:
break

self.assertEqual(data_len, CHUNK * SIZE)

# verify that close_notify is received
sslobj.unwrap()

sock.close()
with filled:
# trigger peer's resume_writing()
incoming.write(sock.recv(65536 * 4))
while True:
try:
chunk = len(sslobj.read(16384))
data_len += chunk
except ssl.SSLWantReadError:
break

def eof_server(sock):
sock.starttls(sslctx, server_side=True)
self.assertEqual(sock.recv_all(4), b'ping')
sock.send(b'pong')
# send close_notify but don't wait for response
with self.assertRaises(ssl.SSLWantReadError):
sslobj.unwrap()
sock.send(outgoing.read())

time.sleep(0.2) # wait for the peer to fill its backlog
with eof_received:
# should receive all data
while True:
try:
chunk = len(sslobj.read(16384))
data_len += chunk
except ssl.SSLWantReadError:
incoming.write(sock.recv(16384))
except ssl.SSLZeroReturnError:
break

# send EOF
sock.shutdown(socket.SHUT_WR)
self.assertEqual(data_len, CHUNK * count)

# should receive all data
data = sock.recv_all(CHUNK * SIZE)
self.assertEqual(len(data), CHUNK * SIZE)
# verify that close_notify is received
sslobj.unwrap()

sock.close()

async def client(addr):
nonlocal future
nonlocal future, count
future = self.loop.create_future()

reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='')
writer.write(b'ping')
data = await reader.readexactly(4)
self.assertEqual(data, b'pong')

# fill write backlog in a hacky way - renegotiation won't help
for _ in range(SIZE):
writer.transport._test__append_write_backlog(b'x' * CHUNK)
with eof_received:
with filled:
reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='')
writer.write(b'ping')
data = await reader.readexactly(4)
self.assertEqual(data, b'pong')

count = 0
try:
while True:
writer.write(b'x' * CHUNK)
count += 1
await asyncio.wait_for(
asyncio.ensure_future(writer.drain()), 0.5)
except asyncio.TimeoutError:
# fill write backlog in a hacky way
for _ in range(SIZE):
writer.transport._test__append_write_backlog(
b'x' * CHUNK)
count += 1

try:
data = await reader.read()
self.assertEqual(data, b'')
except (BrokenPipeError, ConnectionResetError):
pass

await future

Expand All @@ -2728,9 +2735,6 @@ def wrapper(sock):
with self.tcp_server(run(server)) as srv:
self.loop.run_until_complete(client(srv.addr))

with self.tcp_server(run(eof_server)) as srv:
self.loop.run_until_complete(client(srv.addr))

def test_connect_timeout_warning(self):
s = socket.socket(socket.AF_INET)
s.bind(('127.0.0.1', 0))
Expand Down Expand Up @@ -2842,7 +2846,7 @@ def server(sock):
sock.shutdown(socket.SHUT_WR)
loop.call_soon_threadsafe(eof.set)
# make sure we have enough time to reproduce the issue
assert sock.recv(1024) == b''
self.assertEqual(sock.recv(1024), b'')
sock.close()

class Protocol(asyncio.Protocol):
Expand Down Expand Up @@ -2875,7 +2879,92 @@ async def client(addr):
tr.resume_reading()
await pr.fut
tr.close()
assert extra == b'extra bytes'
if self.implementation != 'asyncio':
# extra data received after transport.close() should be
# ignored - this is likely a bug in asyncio
self.assertIsNone(extra)

with self.tcp_server(server) as srv:
loop.run_until_complete(client(srv.addr))

def test_shutdown_while_pause_reading(self):
if self.implementation == 'asyncio':
raise unittest.SkipTest()

loop = self.loop
conn_made = loop.create_future()
eof_recvd = loop.create_future()
conn_lost = loop.create_future()
data_recv = False

def server(sock):
sslctx = self._create_server_ssl_context(self.ONLYCERT,
self.ONLYKEY)
incoming = ssl.MemoryBIO()
outgoing = ssl.MemoryBIO()
sslobj = sslctx.wrap_bio(incoming, outgoing, server_side=True)

while True:
try:
sslobj.do_handshake()
sslobj.write(b'trailing data')
break
except ssl.SSLWantReadError:
if outgoing.pending:
sock.send(outgoing.read())
incoming.write(sock.recv(16384))
if outgoing.pending:
sock.send(outgoing.read())

while True:
try:
self.assertEqual(sslobj.read(), b'') # close_notify
break
except ssl.SSLWantReadError:
incoming.write(sock.recv(16384))

while True:
try:
sslobj.unwrap()
except ssl.SSLWantReadError:
if outgoing.pending:
sock.send(outgoing.read())
# incoming.write(sock.recv(16384))
else:
if outgoing.pending:
sock.send(outgoing.read())
break

self.assertEqual(sock.recv(16384), b'') # socket closed

class Protocol(asyncio.Protocol):
def connection_made(self, transport):
conn_made.set_result(None)

def data_received(self, data):
nonlocal data_recv
data_recv = True

def eof_received(self):
eof_recvd.set_result(None)

def connection_lost(self, exc):
if exc is None:
conn_lost.set_result(None)
else:
conn_lost.set_exception(exc)

async def client(addr):
ctx = self._create_client_ssl_context()
tr, _ = await loop.create_connection(Protocol, *addr, ssl=ctx)
await conn_made
self.assertFalse(data_recv)

tr.pause_reading()
tr.close()

await eof_recvd
await conn_lost

with self.tcp_server(server) as srv:
loop.run_until_complete(client(srv.addr))
Expand Down
1 change: 1 addition & 0 deletions uvloop/includes/stdlib.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ cdef ssl_MemoryBIO = ssl.MemoryBIO
cdef ssl_create_default_context = ssl.create_default_context
cdef ssl_SSLError = ssl.SSLError
cdef ssl_SSLAgainErrors = (ssl.SSLWantReadError, ssl.SSLSyscallError)
cdef ssl_SSLZeroReturnError = ssl.SSLZeroReturnError
cdef ssl_CertificateError = ssl.CertificateError
cdef int ssl_SSL_ERROR_WANT_READ = ssl.SSL_ERROR_WANT_READ
cdef int ssl_SSL_ERROR_WANT_WRITE = ssl.SSL_ERROR_WANT_WRITE
Expand Down
6 changes: 3 additions & 3 deletions uvloop/sslproto.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ cdef enum AppProtocolState:

cdef class _SSLProtocolTransport:
cdef:
object _loop
Loop _loop
SSLProtocol _ssl_protocol
bint _closed

Expand All @@ -41,7 +41,7 @@ cdef class SSLProtocol:
size_t _write_buffer_size

object _waiter
object _loop
Loop _loop
_SSLProtocolTransport _app_transport
bint _app_transport_created

Expand All @@ -65,7 +65,6 @@ cdef class SSLProtocol:

bint _ssl_writing_paused
bint _app_reading_paused
bint _eof_received

size_t _incoming_high_water
size_t _incoming_low_water
Expand Down Expand Up @@ -100,6 +99,7 @@ cdef class SSLProtocol:

cdef _start_shutdown(self)
cdef _check_shutdown_timeout(self)
cdef _do_read_into_void(self)
cdef _do_flush(self)
cdef _do_shutdown(self)
cdef _on_shutdown_complete(self, shutdown_exc)
Expand Down
Loading