From 023de1a6ddc804ca9638d344092ef1584aba7de9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sat, 15 Oct 2022 13:58:53 +0000 Subject: [PATCH 1/5] Remove buffering from asyncio SocketBuffer and rely on the underlying StreamReader --- redis/asyncio/connection.py | 115 +++++++----------------------------- 1 file changed, 21 insertions(+), 94 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index bc0362e782..8e2a17b336 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -207,12 +207,7 @@ async def read_response( class SocketBuffer: - """Async-friendly re-impl of redis-py's SocketBuffer. - - TODO: We're currently passing through two buffers, - the asyncio.StreamReader and this. I imagine we can reduce the layers here - while maintaining compliance with prior art. - """ + """Async-friendly re-impl of redis-py's SocketBuffer.""" def __init__( self, @@ -220,110 +215,42 @@ def __init__( socket_read_size: int, ): self._stream: Optional[asyncio.StreamReader] = stream_reader - self.socket_read_size = socket_read_size - self._buffer: Optional[io.BytesIO] = io.BytesIO() - # number of bytes written to the buffer from the socket - self.bytes_written = 0 - # number of bytes read from the buffer - self.bytes_read = 0 - - @property - def length(self): - return self.bytes_written - self.bytes_read - - async def _read_from_socket(self, length: Optional[int] = None) -> bool: - buf = self._buffer - if buf is None or self._stream is None: - raise RedisError("Buffer is closed.") - buf.seek(self.bytes_written) - marker = 0 - - while True: - data = await self._stream.read(self.socket_read_size) - # an empty string indicates the server shutdown the socket - if isinstance(data, bytes) and len(data) == 0: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - buf.write(data) - data_length = len(data) - self.bytes_written += data_length - marker += data_length - - if length is not None and 
length > marker: - continue - return True async def can_read_destructive(self) -> bool: - if self.length: - return True + if self._stream is None: + raise RedisError("Buffer is closed.") try: async with async_timeout.timeout(0): - return await self._read_from_socket() + return await self._stream.read(1) except asyncio.TimeoutError: return False async def read(self, length: int) -> bytes: - length = length + 2 # make sure to read the \r\n terminator - # make sure we've read enough data from the socket - if length > self.length: - await self._read_from_socket(length - self.length) - - if self._buffer is None: + """ + Read `length` bytes of data. These are assumed to be followed + by a '\r\n' terminator which is subsequently discarded. + """ + if self._stream is None: raise RedisError("Buffer is closed.") - - self._buffer.seek(self.bytes_read) - data = self._buffer.read(length) - self.bytes_read += len(data) - - # purge the buffer when we've consumed it all so it doesn't - # grow forever - if self.bytes_read == self.bytes_written: - self.purge() - + try: + data = await self._stream.readexactly(length + 2) + except asyncio.IncompleteReadError as error: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error return data[:-2] async def readline(self) -> bytes: - buf = self._buffer - if buf is None: + """ + read an unknown number of bytes up to the next '\r\n' + line separator, which is discarded. 
+ """ + if self._stream is None: raise RedisError("Buffer is closed.") - - buf.seek(self.bytes_read) - data = buf.readline() - while not data.endswith(SYM_CRLF): - # there's more data in the socket that we need - await self._read_from_socket() - buf.seek(self.bytes_read) - data = buf.readline() - - self.bytes_read += len(data) - - # purge the buffer when we've consumed it all so it doesn't - # grow forever - if self.bytes_read == self.bytes_written: - self.purge() - + data = await self._stream.readline() + if not data.endswith(b"\r\n"): + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) return data[:-2] - def purge(self): - if self._buffer is None: - raise RedisError("Buffer is closed.") - - self._buffer.seek(0) - self._buffer.truncate() - self.bytes_written = 0 - self.bytes_read = 0 - def close(self): - try: - self.purge() - self._buffer.close() - except Exception: - # issue #633 suggests the purge/close somehow raised a - # BadFileDescriptor error. Perhaps the client ran out of - # memory or something else? It's probably OK to ignore - # any error being raised from purge/close since we're - # removing the reference to the instance below. 
- pass - self._buffer = None self._stream = None From bf78ba55bafda42ccc05f9495cba626a27d10712 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sat, 15 Oct 2022 14:12:57 +0000 Subject: [PATCH 2/5] Skip the use of SocketBuffer in PythonParser --- redis/asyncio/connection.py | 47 +++++++++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 8e2a17b336..de14e2e8af 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -184,7 +184,7 @@ def parse_error(self, response: str) -> ResponseError: """Parse an error response""" error_code = response.split(" ")[0] if error_code in self.EXCEPTION_CLASSES: - response = response[len(error_code) + 1 :] + response = response[len(error_code) + 1:] exception_class = self.EXCEPTION_CLASSES[error_code] if isinstance(exception_class, dict): exception_class = exception_class.get(response, ResponseError) @@ -269,27 +269,29 @@ def on_connect(self, connection: "Connection"): if self._stream is None: raise RedisError("Buffer is closed.") - self._buffer = SocketBuffer(self._stream, self._read_size) self.encoder = connection.encoder def on_disconnect(self): """Called when the stream disconnects""" if self._stream is not None: self._stream = None - if self._buffer is not None: - self._buffer.close() - self._buffer = None self.encoder = None - async def can_read_destructive(self): - return self._buffer and bool(await self._buffer.can_read_destructive()) + async def can_read_destructive(self) -> bool: + if self._stream is None: + raise RedisError("Buffer is closed.") + try: + async with async_timeout.timeout(0): + return await self._stream.read(1) + except asyncio.TimeoutError: + return False async def read_response( self, disable_decoding: bool = False ) -> Union[EncodableT, ResponseError, None]: - if not self._buffer or not self.encoder: + if not self._stream or not self.encoder: raise 
ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - raw = await self._buffer.readline() + raw = await self._readline() if not raw: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) response: Any @@ -322,7 +324,7 @@ async def read_response( length = int(response) if length == -1: return None - response = await self._buffer.read(length) + response = await self._read(length) # multi-bulk response elif byte == b"*": length = int(response) @@ -335,6 +337,31 @@ async def read_response( response = self.encoder.decode(response) return response + async def _read(self, length: int) -> bytes: + """ + Read `length` bytes of data. These are assumed to be followed + by a '\r\n' terminator which is subsequently discarded. + """ + if self._stream is None: + raise RedisError("Buffer is closed.") + try: + data = await self._stream.readexactly(length + 2) + except asyncio.IncompleteReadError as error: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error + return data[:-2] + + async def _readline(self) -> bytes: + """ + read an unknown number of bytes up to the next '\r\n' + line separator, which is discarded. 
+ """ + if self._stream is None: + raise RedisError("Buffer is closed.") + data = await self._stream.readline() + if not data.endswith(b"\r\n"): + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + return data[:-2] + class HiredisParser(BaseParser): """Parser class for connections using Hiredis""" From f852682f67d40671b3a994865769b454eda680a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sat, 15 Oct 2022 14:18:05 +0000 Subject: [PATCH 3/5] Remove SocketBuffer altogether --- redis/asyncio/connection.py | 51 +------------------------------------ 1 file changed, 1 insertion(+), 50 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index de14e2e8af..c37796ee59 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -141,7 +141,7 @@ def decode(self, value: EncodableT, force=False) -> EncodableT: class BaseParser: """Plain Python parsing class""" - __slots__ = "_stream", "_buffer", "_read_size" + __slots__ = "_stream", "_read_size" EXCEPTION_CLASSES: ExceptionMappingT = { "ERR": { @@ -171,7 +171,6 @@ class BaseParser: def __init__(self, socket_read_size: int): self._stream: Optional[asyncio.StreamReader] = None - self._buffer: Optional[SocketBuffer] = None self._read_size = socket_read_size def __del__(self): @@ -206,54 +205,6 @@ async def read_response( raise NotImplementedError() -class SocketBuffer: - """Async-friendly re-impl of redis-py's SocketBuffer.""" - - def __init__( - self, - stream_reader: asyncio.StreamReader, - socket_read_size: int, - ): - self._stream: Optional[asyncio.StreamReader] = stream_reader - - async def can_read_destructive(self) -> bool: - if self._stream is None: - raise RedisError("Buffer is closed.") - try: - async with async_timeout.timeout(0): - return await self._stream.read(1) - except asyncio.TimeoutError: - return False - - async def read(self, length: int) -> bytes: - """ - Read `length` bytes of data. 
These are assumed to be followed - by a '\r\n' terminator which is subsequently discarded. - """ - if self._stream is None: - raise RedisError("Buffer is closed.") - try: - data = await self._stream.readexactly(length + 2) - except asyncio.IncompleteReadError as error: - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error - return data[:-2] - - async def readline(self) -> bytes: - """ - read an unknown number of bytes up to the next '\r\n' - line separator, which is discarded. - """ - if self._stream is None: - raise RedisError("Buffer is closed.") - data = await self._stream.readline() - if not data.endswith(b"\r\n"): - raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - return data[:-2] - - def close(self): - self._stream = None - - class PythonParser(BaseParser): """Plain Python parsing class""" From 5a3f7b17ab61ab6366621ccf46a7a5638f3e06f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sat, 15 Oct 2022 16:39:33 +0000 Subject: [PATCH 4/5] Code cleanup --- redis/asyncio/connection.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index c37796ee59..b64bd125eb 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -2,7 +2,6 @@ import copy import enum import inspect -import io import os import socket import ssl @@ -183,7 +182,7 @@ def parse_error(self, response: str) -> ResponseError: """Parse an error response""" error_code = response.split(" ")[0] if error_code in self.EXCEPTION_CLASSES: - response = response[len(error_code) + 1:] + response = response[len(error_code) + 1 :] exception_class = self.EXCEPTION_CLASSES[error_code] if isinstance(exception_class, dict): exception_class = exception_class.get(response, ResponseError) From f4aee4c6189674f0161c14de19e1f01c1f682ee5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sat, 15 Oct 2022 17:07:50 +0000 Subject: [PATCH 5/5] 
Fix unittest mocking when SocketBuffer is gone --- tests/test_asyncio/test_connection.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 674a1b9980..6bf0034146 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -13,22 +13,23 @@ from redis.asyncio.retry import Retry from redis.backoff import NoBackoff from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError -from redis.utils import HIREDIS_AVAILABLE from tests.conftest import skip_if_server_version_lt from .compat import mock @pytest.mark.onlynoncluster -@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") async def test_invalid_response(create_redis): r = await create_redis(single_connection_client=True) raw = b"x" - readline_mock = mock.AsyncMock(return_value=raw) parser: "PythonParser" = r.connection._parser - with mock.patch.object(parser._buffer, "readline", readline_mock): + if not isinstance(parser, PythonParser): + pytest.skip("PythonParser only") + stream_mock = mock.Mock(parser._stream) + stream_mock.readline.return_value = raw + b"\r\n" + with mock.patch.object(parser, "_stream", stream_mock): with pytest.raises(InvalidResponse) as cm: await parser.read_response() assert str(cm.value) == f"Protocol Error: {raw!r}"