Remove the superfluous SocketBuffer from asyncio PythonParser #2418

Merged · merged 5 commits · Oct 30, 2022
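This PR removes the io.BytesIO-backed SocketBuffer that sat between PythonParser and the asyncio.StreamReader, and has the parser read straight from the stream, which already buffers incoming data. A minimal sketch of the resulting read pattern (illustrative only; read_bulk_reply and read_line_reply are hypothetical helper names, not redis-py API):

import asyncio

async def read_bulk_reply(stream: asyncio.StreamReader, length: int) -> bytes:
    # Read `length` payload bytes plus the trailing CRLF, then drop the CRLF.
    try:
        data = await stream.readexactly(length + 2)
    except asyncio.IncompleteReadError as err:
        raise ConnectionError("server closed the connection") from err
    return data[:-2]

async def read_line_reply(stream: asyncio.StreamReader) -> bytes:
    # Read up to and including the next CRLF; a line that does not end in
    # CRLF means the server closed the connection before a full reply arrived.
    data = await stream.readline()
    if not data.endswith(b"\r\n"):
        raise ConnectionError("server closed the connection")
    return data[:-2]

Because StreamReader keeps its own internal buffer, there is no longer a second copy of every chunk in an io.BytesIO, and no bytes_read/bytes_written bookkeeping or manual purge() calls.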
170 changes: 37 additions & 133 deletions redis/asyncio/connection.py
@@ -2,7 +2,6 @@
import copy
import enum
import inspect
import io
import os
import socket
import ssl
@@ -141,7 +140,7 @@ def decode(self, value: EncodableT, force=False) -> EncodableT:
class BaseParser:
"""Plain Python parsing class"""

__slots__ = "_stream", "_buffer", "_read_size"
__slots__ = "_stream", "_read_size"

EXCEPTION_CLASSES: ExceptionMappingT = {
"ERR": {
@@ -171,7 +170,6 @@ class BaseParser:

def __init__(self, socket_read_size: int):
self._stream: Optional[asyncio.StreamReader] = None
self._buffer: Optional[SocketBuffer] = None
self._read_size = socket_read_size

def __del__(self):
@@ -206,127 +204,6 @@ async def read_response(
raise NotImplementedError()


class SocketBuffer:
"""Async-friendly re-impl of redis-py's SocketBuffer.

TODO: We're currently passing through two buffers,
the asyncio.StreamReader and this. I imagine we can reduce the layers here
while maintaining compliance with prior art.
"""

def __init__(
self,
stream_reader: asyncio.StreamReader,
socket_read_size: int,
):
self._stream: Optional[asyncio.StreamReader] = stream_reader
self.socket_read_size = socket_read_size
self._buffer: Optional[io.BytesIO] = io.BytesIO()
# number of bytes written to the buffer from the socket
self.bytes_written = 0
# number of bytes read from the buffer
self.bytes_read = 0

@property
def length(self):
return self.bytes_written - self.bytes_read

async def _read_from_socket(self, length: Optional[int] = None) -> bool:
buf = self._buffer
if buf is None or self._stream is None:
raise RedisError("Buffer is closed.")
buf.seek(self.bytes_written)
marker = 0

while True:
data = await self._stream.read(self.socket_read_size)
# an empty string indicates the server shutdown the socket
if isinstance(data, bytes) and len(data) == 0:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
buf.write(data)
data_length = len(data)
self.bytes_written += data_length
marker += data_length

if length is not None and length > marker:
continue
return True

async def can_read_destructive(self) -> bool:
if self.length:
return True
try:
async with async_timeout.timeout(0):
return await self._read_from_socket()
except asyncio.TimeoutError:
return False

async def read(self, length: int) -> bytes:
length = length + 2 # make sure to read the \r\n terminator
# make sure we've read enough data from the socket
if length > self.length:
await self._read_from_socket(length - self.length)

if self._buffer is None:
raise RedisError("Buffer is closed.")

self._buffer.seek(self.bytes_read)
data = self._buffer.read(length)
self.bytes_read += len(data)

# purge the buffer when we've consumed it all so it doesn't
# grow forever
if self.bytes_read == self.bytes_written:
self.purge()

return data[:-2]

async def readline(self) -> bytes:
buf = self._buffer
if buf is None:
raise RedisError("Buffer is closed.")

buf.seek(self.bytes_read)
data = buf.readline()
while not data.endswith(SYM_CRLF):
# there's more data in the socket that we need
await self._read_from_socket()
buf.seek(self.bytes_read)
data = buf.readline()

self.bytes_read += len(data)

# purge the buffer when we've consumed it all so it doesn't
# grow forever
if self.bytes_read == self.bytes_written:
self.purge()

return data[:-2]

def purge(self):
if self._buffer is None:
raise RedisError("Buffer is closed.")

self._buffer.seek(0)
self._buffer.truncate()
self.bytes_written = 0
self.bytes_read = 0

def close(self):
try:
self.purge()
self._buffer.close()
except Exception:
# issue #633 suggests the purge/close somehow raised a
# BadFileDescriptor error. Perhaps the client ran out of
# memory or something else? It's probably OK to ignore
# any error being raised from purge/close since we're
# removing the reference to the instance below.
pass
self._buffer = None
self._stream = None


class PythonParser(BaseParser):
"""Plain Python parsing class"""

@@ -342,27 +219,29 @@ def on_connect(self, connection: "Connection"):
if self._stream is None:
raise RedisError("Buffer is closed.")

self._buffer = SocketBuffer(self._stream, self._read_size)
self.encoder = connection.encoder

def on_disconnect(self):
"""Called when the stream disconnects"""
if self._stream is not None:
self._stream = None
if self._buffer is not None:
self._buffer.close()
self._buffer = None
self.encoder = None

async def can_read_destructive(self):
return self._buffer and bool(await self._buffer.can_read_destructive())
async def can_read_destructive(self) -> bool:
if self._stream is None:
raise RedisError("Buffer is closed.")
try:
async with async_timeout.timeout(0):
return await self._stream.read(1)
except asyncio.TimeoutError:
return False

async def read_response(
self, disable_decoding: bool = False
) -> Union[EncodableT, ResponseError, None]:
if not self._buffer or not self.encoder:
if not self._stream or not self.encoder:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
raw = await self._buffer.readline()
raw = await self._readline()
if not raw:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
response: Any
@@ -395,7 +274,7 @@ async def read_response(
length = int(response)
if length == -1:
return None
response = await self._buffer.read(length)
response = await self._read(length)
# multi-bulk response
elif byte == b"*":
length = int(response)
@@ -408,6 +287,31 @@
response = self.encoder.decode(response)
return response

async def _read(self, length: int) -> bytes:
"""
Read `length` bytes of data. These are assumed to be followed
by a '\r\n' terminator which is subsequently discarded.
"""
if self._stream is None:
raise RedisError("Buffer is closed.")
try:
data = await self._stream.readexactly(length + 2)
except asyncio.IncompleteReadError as error:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from error
return data[:-2]

async def _readline(self) -> bytes:
"""
read an unknown number of bytes up to the next '\r\n'
line separator, which is discarded.
"""
if self._stream is None:
raise RedisError("Buffer is closed.")
data = await self._stream.readline()
if not data.endswith(b"\r\n"):
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
return data[:-2]


class HiredisParser(BaseParser):
"""Parser class for connections using Hiredis"""
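To make the read flow above concrete, here is a rough, self-contained illustration (not redis-py code; parse_reply and main are made-up names) of how line-delimited and length-prefixed RESP replies come straight off an asyncio.StreamReader, mirroring the new _readline and _read helpers:

import asyncio

async def parse_reply(stream: asyncio.StreamReader):
    line = await stream.readline()            # e.g. b"+OK\r\n" or b"$5\r\n"
    marker, rest = line[:1], line[1:-2]
    if marker == b"+":                         # simple string
        return rest
    if marker == b"$":                         # bulk string: length line, then payload + CRLF
        length = int(rest)
        if length == -1:                       # RESP null bulk string
            return None
        data = await stream.readexactly(length + 2)
        return data[:-2]
    raise ValueError(f"unhandled reply type: {marker!r}")

async def main():
    # Feed a canned bulk-string reply through a StreamReader for demonstration.
    reader = asyncio.StreamReader()
    reader.feed_data(b"$5\r\nhello\r\n")
    reader.feed_eof()
    print(await parse_reply(reader))           # b'hello'

asyncio.run(main())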
9 changes: 5 additions & 4 deletions tests/test_asyncio/test_connection.py
@@ -13,22 +13,23 @@
from redis.asyncio.retry import Retry
from redis.backoff import NoBackoff
from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError
from redis.utils import HIREDIS_AVAILABLE
from tests.conftest import skip_if_server_version_lt

from .compat import mock


@pytest.mark.onlynoncluster
@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
async def test_invalid_response(create_redis):
r = await create_redis(single_connection_client=True)

raw = b"x"
readline_mock = mock.AsyncMock(return_value=raw)

parser: "PythonParser" = r.connection._parser
with mock.patch.object(parser._buffer, "readline", readline_mock):
if not isinstance(parser, PythonParser):
pytest.skip("PythonParser only")
stream_mock = mock.Mock(parser._stream)
stream_mock.readline.return_value = raw + b"\r\n"
with mock.patch.object(parser, "_stream", stream_mock):
with pytest.raises(InvalidResponse) as cm:
await parser.read_response()
assert str(cm.value) == f"Protocol Error: {raw!r}"
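The updated test swaps the parser's StreamReader for a mock whose readline() returns a malformed reply, instead of patching the removed SocketBuffer. A compact sketch of the same pattern under assumed names (assert_invalid_reply is hypothetical, not part of the test suite):

from unittest import mock

import pytest

from redis.exceptions import InvalidResponse

async def assert_invalid_reply(parser, raw: bytes = b"x") -> None:
    # Stand in for the real StreamReader; readline() hands back an
    # unparseable first line, so read_response() must raise InvalidResponse.
    stream_mock = mock.Mock(parser._stream)
    stream_mock.readline = mock.AsyncMock(return_value=raw + b"\r\n")
    with mock.patch.object(parser, "_stream", stream_mock):
        with pytest.raises(InvalidResponse):
            await parser.read_response()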