diff --git a/tests/resp.py b/tests/resp.py new file mode 100644 index 0000000000..c0b8895527 --- /dev/null +++ b/tests/resp.py @@ -0,0 +1,531 @@ +import itertools +from contextlib import closing +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union + +CRNL = b"\r\n" + + +class VerbatimStr(str): + """ + A string that is encoded as a resp3 verbatim string + """ + + def __new__(cls, value: str, hint: str) -> "VerbatimStr": + return str.__new__(cls, value) + + def __init__(self, value: str, hint: str) -> None: + self.hint = hint + + def __repr__(self) -> str: + return f"VerbatimStr({super().__repr__()}, {self.hint!r})" + + +class ErrorStr(str): + """ + A string to be encoded as a resp3 error + """ + + def __new__(cls, code: str, value: str) -> "ErrorStr": + return str.__new__(cls, value) + + def __init__(self, code: str, value: str) -> None: + self.code = code.upper() + + def __repr__(self) -> str: + return f"ErrorString({self.code!r}, {super().__repr__()})" + + def __str__(self) -> str: + return f"{self.code} {super().__str__()}" + + +class PushData(List[Any]): + """ + A special type of list indicating data from a push response + """ + + def __repr__(self) -> str: + return f"PushData({super().__repr__()})" + + +class Attribute(Dict[Any, Any]): + """ + A special type of map indicating data from a attribute response + """ + + def __repr__(self) -> str: + return f"Attribute({super().__repr__()})" + + +class RespEncoder: + """ + A class for simple RESP protocol encoder for unit tests + """ + + def __init__( + self, protocol: int = 2, encoding: str = "utf-8", errorhander: str = "strict" + ) -> None: + self.protocol = protocol + self.encoding = encoding + self.errorhandler = errorhander + + def apply_encoding(self, value: str) -> bytes: + return value.encode(self.encoding, errors=self.errorhandler) + + def has_crnl(self, value: bytes) -> bool: + """check if either cr or nl is in the value""" + return b"\r" in value or b"\n" in value + + def escape_crln(self, value: bytes) -> bytes: + """remove any cr or nl from the value""" + return value.replace(b"\r", b"\\r").replace(b"\n", b"\\n") + + def encode(self, data: Any, hint: Optional[str] = None) -> bytes: + if isinstance(data, dict): + if self.protocol > 2: + code = "|" if isinstance(data, Attribute) else "%" + result = f"{code}{len(data)}\r\n".encode() + for key, val in data.items(): + result += self.encode(key) + self.encode(val) + return result + else: + # Automatically encode dicts as flattened key, value arrays + mylist = list( + itertools.chain(*((key, val) for (key, val) in data.items())) + ) + return self.encode(mylist) + + elif isinstance(data, list): + code = ">" if isinstance(data, PushData) and self.protocol > 2 else "*" + result = f"{code}{len(data)}\r\n".encode() + for val in data: + result += self.encode(val) + return result + + elif isinstance(data, set): + if self.protocol > 2: + result = f"~{len(data)}\r\n".encode() + for val in data: + result += self.encode(val) + return result + else: + return self.encode(list(data)) + + elif isinstance(data, ErrorStr): + enc = self.apply_encoding(str(data)) + if self.protocol > 2: + if len(enc) > 80 or self.has_crnl(enc): + return f"!{len(enc)}\r\n".encode() + enc + b"\r\n" + return b"-" + self.escape_crln(enc) + b"\r\n" + + elif isinstance(data, str): + enc = self.apply_encoding(data) + # long strings or strings with control characters must be encoded as bulk + # strings + if hint or len(enc) > 80 or self.has_crnl(enc): + return self.encode_bulkstr(enc, hint) + return b"+" + enc + b"\r\n" + + elif isinstance(data, bytes): + return self.encode_bulkstr(data, hint) + + elif isinstance(data, bool): + if self.protocol == 2: + return b":1\r\n" if data else b":0\r\n" + return b"t\r\n" if data else b"f\r\n" + + elif isinstance(data, int): + if (data > 2**63 - 1) or (data < -(2**63)): + if self.protocol > 2: + return f"({data}\r\n".encode() # resp3 big int + return f"+{data}\r\n".encode() # force to simple string + return f":{data}\r\n".encode() + elif isinstance(data, float): + if self.protocol > 2: + return f",{data}\r\n".encode() # resp3 double + return f"+{data}\r\n".encode() # simple string + + elif data is None: + if self.protocol > 2: + return b"_\r\n" # resp3 null + return b"$-1\r\n" # Null bulk string + # some commands return null array: b"*-1\r\n" + + else: + raise NotImplementedError(f"encode not implemented for {type(data)}") + + def encode_bulkstr(self, bstr: bytes, hint: Optional[str]) -> bytes: + if self.protocol > 2 and hint is not None: + # a resp3 verbatim string + return f"={len(bstr)}\r\n{hint}:".encode() + bstr + b"\r\n" + # regular bulk string + return f"${len(bstr)}\r\n".encode() + bstr + b"\r\n" + + +def encode(value: Any, protocol: int = 2, hint: Optional[str] = None) -> bytes: + """ + Encode a value using the RESP protocol + """ + return RespEncoder(protocol).encode(value, hint) + + +class RespGeneratorParser: + """ + A wrapper class around a stateful RESP parsing generator, + allowing custom string decoding rules. + """ + + def __init__(self, encoding: str = "utf-8", errorhandler: str = "surrogateescape"): + """ + Create a new parser, optionally specifying the encoding and errorhandler. + If `encoding` is None, bytes will be returned as-is. + The default settings are utf-8 encoding and surrogateescape errorhandler, + which can decode all possible byte sequences, + allowing decoded data to be re-encoded back to bytes. + """ + self.encoding = encoding + self.errorhandler = errorhandler + + def decode_bytes(self, data: bytes) -> str: + """ + decode the data as a string, + """ + return data.decode(self.encoding, errors=self.errorhandler) + + def parse( + self, + buffer: bytes, + ) -> Generator[Optional[Tuple[Any, bytes]], Union[None, bytes], None]: + """ + A stateful, generator based, RESP parser. + Returns a generator producing at most a single top-level primitive. + Yields tuple of (data_item, unparsed), or None if more data is needed. + It is fed more data with generator.send() + """ + # Read the first line of resp or yield to get more data + while CRNL not in buffer: + incoming = yield None + assert incoming is not None + buffer += incoming + cmd, rest = buffer.split(CRNL, 1) + + code, arg = cmd[:1], cmd[1:] + + if code == b":" or code == b"(": # integer, resp3 large int + yield int(arg), rest + + elif code == b"t": # resp3 true + yield True, rest + + elif code == b"f": # resp3 false + yield False, rest + + elif code == b"_": # resp3 null + yield None, rest + + elif code == b",": # resp3 double + yield float(arg), rest + + elif code == b"+": # simple string + # we decode them automatically + yield self.decode_bytes(arg), rest + + elif code == b"$": # bulk string + count = int(arg) + expect = count + 2 # +2 for the trailing CRNL + while len(rest) < expect: + incoming = yield (None) + assert incoming is not None + rest += incoming + bulkstr = rest[:count] + yield self.decode_bytes(bulkstr), rest[expect:] + + elif code == b"=": # verbatim strings + count = int(arg) + expect = count + 4 + 2 # 4 type and colon +2 for the trailing CRNL + while len(rest) < expect: + incoming = yield (None) + assert incoming is not None + rest += incoming + string = self.decode_bytes(rest[: (count + 4)]) + if string[3] != ":": + raise ValueError(f"Expected colon after hint, got {string[3]}") + hint = string[:3] + string = string[4 : (count + 4)] + yield VerbatimStr(string, hint), rest[expect:] + + elif code in b"*>": # array or push data + count = int(arg) + result_array = [] + for _ in range(count): + # recursively parse the next array item + with closing(self.parse(rest)) as parser: + parsed = parser.send(None) + while parsed is None: + incoming = yield None + parsed = parser.send(incoming) + value, rest = parsed + result_array.append(value) + if code == b">": + yield PushData(result_array), rest + else: + yield result_array, rest + + elif code == b"~": # set + count = int(arg) + result_set = set() + for _ in range(count): + # recursively parse the next set item + with closing(self.parse(rest)) as parser: + parsed = parser.send(None) + while parsed is None: + incoming = yield None + parsed = parser.send(incoming) + value, rest = parsed + result_set.add(value) + yield result_set, rest + + elif code in b"%|": # map or attribute + count = int(arg) + result_map = {} + for _ in range(count): + # recursively parse the next key, and value + with closing(self.parse(rest)) as parser: + parsed = parser.send(None) + while parsed is None: + incoming = yield None + parsed = parser.send(incoming) + key, rest = parsed + with closing(self.parse(rest)) as parser: + parsed = parser.send(None) + while parsed is None: + incoming = yield None + parsed = parser.send(incoming) + value, rest = parsed + result_map[key] = value + if code == b"|": + yield Attribute(result_map), rest + yield result_map, rest + + elif code == b"-": # error + # we decode them automatically + decoded = self.decode_bytes(arg) + assert isinstance(decoded, str) + err, value = decoded.split(" ", 1) + yield ErrorStr(err, value), rest + + elif code == b"!": # resp3 error + count = int(arg) + expect = count + 2 # +2 for the trailing CRNL + while len(rest) < expect: + incoming = yield (None) + assert incoming is not None + rest += incoming + bulkstr = rest[:count] + decoded = self.decode_bytes(bulkstr) + assert isinstance(decoded, str) + err, value = decoded.split(" ", 1) + yield ErrorStr(err, value), rest[expect:] + + else: + raise ValueError(f"Unknown opcode '{code.decode()}'") + + +class NeedMoreData(RuntimeError): + """ + Raised when more data is needed to complete a parse + """ + + +class RespParser: + """ + A class for simple RESP protocol decoding for unit tests. + Uses a RespGeneratorParser to produce data, and can + produce top-level objects for as long as there is data available. + """ + + def __init__(self) -> None: + self.parser = RespGeneratorParser() + self.generator: Optional[ + Generator[Optional[Tuple[Any, bytes]], Union[None, bytes], None] + ] = None + # which has not resulted in a parsed value + self.consumed: List[bytes] = [] + + def parse(self, buffer: bytes) -> Optional[Any]: + """ + Parse a buffer of data, return a tuple of a single top-level primitive and the + remaining buffer or raise NeedMoreData if more data is needed to + produce a value. + """ + if self.generator is None: + # create a new parser generator, initializing it with + # any unparsed data from previous calls + buffer = b"".join(self.consumed) + buffer + self.consumed.clear() + self.generator = self.parser.parse(buffer) + parsed = self.generator.send(None) + else: + # sen more data to the parser + parsed = self.generator.send(buffer) + + if parsed is None: + self.consumed.append(buffer) + raise NeedMoreData() + + # got a value, close the generator, store the remaining buffer + self.generator.close() + self.generator = None + value, remaining = parsed + self.consumed = [remaining] + return value + + def get_unparsed(self) -> bytes: + return b"".join(self.consumed) + + def close(self) -> None: + if self.generator is not None: + self.generator.close() + self.generator = None + self.consumed.clear() + + +def parse_all(buffer: bytes) -> Tuple[List[Any], bytes]: + """ + Parse all the data in the buffer, returning the list of top-level objects and the + remaining buffer + """ + with closing(RespParser()) as parser: + result: List[Any] = [] + while True: + try: + result.append(parser.parse(buffer)) + buffer = b"" + except NeedMoreData: + return result, parser.get_unparsed() + + +def parse_chunks(buffers: List[bytes]) -> Tuple[List[Any], bytes]: + """ + Parse all the data in the buffers, returning the list of top-level objects and the + remaining buffer. + Used primarily for testing, since it will parse the data in chunks + """ + result: List[Any] = [] + with closing(RespParser()) as parser: + for buffer in buffers: + while True: + try: + result.append(parser.parse(buffer)) + buffer = b"" + except NeedMoreData: + break + return result, parser.get_unparsed() + + +class RespServer: + """A simple, dummy, REDIS server for unit tests. + Accepts RESP commands and returns RESP responses. + """ + + handlers: Dict[str, Callable[..., Any]] = {} + + def __init__(self) -> None: + self.protocol = 2 + self.server_ver = self.get_server_version() + self.auth: List[Any] = [] + self.client_name = "" + + # patchable methods for testing + + def get_server_version(self) -> int: + return 6 + + def on_auth(self, auth: List[Any]) -> None: + pass + + def on_setname(self, name: str) -> None: + pass + + def on_protocol(self, proto: int) -> None: + pass + + def command(self, cmd: Any) -> bytes: + """Process a single command and return the response""" + result = self._command(cmd) + return RespEncoder(self.protocol).encode(result) + + def _command(self, cmd: Any) -> Any: + if not isinstance(cmd, list): + return ErrorStr("ERR", "unknown command {cmd!r}") + + # handle registered commands + command = cmd[0].upper() + args = cmd[1:] + if command in self.handlers: + return self.handlers[command](self, args) + + return ErrorStr("ERR", "unknown command {cmd!r}") + + def handle_auth(self, args: List[Any]) -> Union[str, ErrorStr]: + self.auth = args[:] + self.on_auth(self.auth) + expect = 2 if self.server_ver >= 6 else 1 + if len(args) != expect: + return ErrorStr("ERR", "wrong number of arguments" " for 'AUTH' command") + return "OK" + + handlers["AUTH"] = handle_auth + + def handle_client(self, args: List[Any]) -> Union[str, ErrorStr]: + if args[0] == "SETNAME": + return self.handle_setname(args[1:]) + return ErrorStr("ERR", "unknown subcommand or wrong number of arguments") + + handlers["CLIENT"] = handle_client + + def handle_setname(self, args: List[Any]) -> Union[str, ErrorStr]: + if len(args) != 1: + return ErrorStr("ERR", "wrong number of arguments") + self.client_name = args[0] + self.on_setname(self.client_name) + return "OK" + + def handle_hello(self, args: List[Any]) -> Union[ErrorStr, Dict[str, Any]]: + if self.server_ver < 6: + return ErrorStr("ERR", "unknown command 'HELLO'") + proto = self.protocol + if args: + proto = args.pop(0) + if str(proto) not in ["2", "3"]: + return ErrorStr( + "NOPROTO", "sorry this protocol version is not supported" + ) + + while args: + cmd = args.pop(0).upper() + if cmd == "AUTH": + auth_args = args[:2] + args = args[2:] + res = self.handle_auth(auth_args) + if isinstance(res, ErrorStr): + return res + continue + if cmd == "SETNAME": + setname_args = args[:1] + args = args[1:] + res = self.handle_setname(setname_args) + if isinstance(res, ErrorStr): + return res + continue + return ErrorStr("ERR", "unknown subcommand or wrong number of arguments") + + self.protocol = int(proto) + self.on_protocol(self.protocol) + result = { + "server": "redistester", + "version": "0.0.1", + "proto": self.protocol, + } + return result + + handlers["HELLO"] = handle_hello diff --git a/tests/test_asyncio/test_connect.py b/tests/test_asyncio/test_connect.py index 5e6b120fb3..b59310578a 100644 --- a/tests/test_asyncio/test_connect.py +++ b/tests/test_asyncio/test_connect.py @@ -1,26 +1,24 @@ import asyncio import logging -import re import socket import ssl +from unittest.mock import patch import pytest from redis.asyncio.connection import ( Connection, + ResponseError, SSLConnection, UnixDomainSocketConnection, ) +from .. import resp from ..ssl_utils import get_ssl_filename _logger = logging.getLogger(__name__) _CLIENT_NAME = "test-suite-client" -_CMD_SEP = b"\r\n" -_SUCCESS_RESP = b"+OK" + _CMD_SEP -_ERROR_RESP = b"-ERR" + _CMD_SEP -_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} @pytest.fixture @@ -65,6 +63,90 @@ async def test_tcp_ssl_connect(tcp_address): await conn.disconnect() +@pytest.mark.parametrize( + ("use_server_ver", "use_protocol", "use_auth", "use_client_name"), + [ + (5, 2, False, True), + (5, 2, True, True), + (5, 3, True, True), + (6, 2, False, True), + (6, 2, True, True), + (6, 3, False, False), + (6, 3, True, False), + (6, 3, False, True), + (6, 3, True, True), + ], +) +# @pytest.mark.parametrize("use_protocol", [2, 3]) +# @pytest.mark.parametrize("use_auth", [False, True]) +async def test_tcp_auth( + tcp_address, use_protocol, use_auth, use_server_ver, use_client_name +): + """ + Test that various initial handshake cases are handled correctly by the client + """ + got_auth = [] + got_protocol = None + got_name = None + + def on_auth(self, auth): + got_auth[:] = auth + + def on_protocol(self, proto): + nonlocal got_protocol + got_protocol = proto + + def on_setname(self, name): + nonlocal got_name + got_name = name + + def get_server_version(self): + return use_server_ver + + if use_auth: + auth_args = {"username": "myuser", "password": "mypassword"} + else: + auth_args = {} + got_protocol = None + host, port = tcp_address + conn = Connection( + host=host, + port=port, + client_name=_CLIENT_NAME if use_client_name else None, + socket_timeout=10, + protocol=use_protocol, + **auth_args, + ) + try: + with patch.multiple( + resp.RespServer, + on_auth=on_auth, + get_server_version=get_server_version, + on_protocol=on_protocol, + on_setname=on_setname, + ): + if use_server_ver < 6 and use_protocol > 2: + with pytest.raises(ResponseError): + await _assert_connect(conn, tcp_address) + return + + await _assert_connect(conn, tcp_address) + if use_protocol == 3: + assert got_protocol == use_protocol + if use_auth: + if use_server_ver < 6: + assert got_auth == ["mypassword"] + else: + assert got_auth == ["myuser", "mypassword"] + + if use_client_name: + assert got_name == _CLIENT_NAME + else: + assert got_name is None + finally: + await conn.disconnect() + + async def _assert_connect(conn, server_address, certfile=None, keyfile=None): stop_event = asyncio.Event() finished = asyncio.Event() @@ -102,46 +184,34 @@ async def _handler(reader, writer): async def _redis_request_handler(reader, writer, stop_event): + parser = resp.RespParser() + server = resp.RespServer() buffer = b"" - command = None - command_ptr = None - fragment_length = None - while not stop_event.is_set() or buffer: - _logger.info(str(stop_event.is_set())) - try: - buffer += await asyncio.wait_for(reader.read(1024), timeout=0.5) - except TimeoutError: - continue - if not buffer: - continue - parts = re.split(_CMD_SEP, buffer) - buffer = parts[-1] - for fragment in parts[:-1]: - fragment = fragment.decode() - _logger.info("Command fragment: %s", fragment) - - if fragment.startswith("*") and command is None: - command = [None for _ in range(int(fragment[1:]))] - command_ptr = 0 - fragment_length = None - continue - - if fragment.startswith("$") and command[command_ptr] is None: - fragment_length = int(fragment[1:]) - continue - - assert len(fragment) == fragment_length - command[command_ptr] = fragment - command_ptr += 1 - - if command_ptr < len(command): + try: + # if client performs pipelining, we may need + # to adjust this code to not block when sending + # responses. + while not stop_event.is_set() or buffer: + _logger.info(str(stop_event.is_set())) + try: + command = parser.parse(buffer) + buffer = b"" + except resp.NeedMoreData: + try: + buffer = await asyncio.wait_for(reader.read(1024), timeout=0.5) + except TimeoutError: + buffer = b"" + continue + if not buffer: + break # EOF continue - command = " ".join(command) _logger.info("Command %s", command) - resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) - _logger.info("Response from %s", resp) - writer.write(resp) + response = server.command(command) + _logger.info("Response %s", response) + writer.write(response) await writer.drain() - command = None - _logger.info("Exit handler") + except Exception: + _logger.exception("Error in handler") + finally: + _logger.info("Exit handler") diff --git a/tests/test_connect.py b/tests/test_connect.py index 696e69ceea..49c3abe506 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1,23 +1,25 @@ import logging -import re import socket import socketserver import ssl import threading +from unittest.mock import patch import pytest -from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection - +from redis.connection import ( + Connection, + ResponseError, + SSLConnection, + UnixDomainSocketConnection, +) + +from . import resp from .ssl_utils import get_ssl_filename _logger = logging.getLogger(__name__) _CLIENT_NAME = "test-suite-client" -_CMD_SEP = b"\r\n" -_SUCCESS_RESP = b"+OK" + _CMD_SEP -_ERROR_RESP = b"-ERR" + _CMD_SEP -_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} @pytest.fixture @@ -59,6 +61,88 @@ def test_tcp_ssl_connect(tcp_address): _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) +@pytest.mark.parametrize( + ("use_server_ver", "use_protocol", "use_auth", "use_client_name"), + [ + (5, 2, False, True), + (5, 2, True, True), + (5, 3, True, True), + (6, 2, False, True), + (6, 2, True, True), + (6, 3, False, False), + (6, 3, True, False), + (6, 3, False, True), + (6, 3, True, True), + ], +) +# @pytest.mark.parametrize("use_protocol", [2, 3]) +# @pytest.mark.parametrize("use_auth", [False, True]) +def test_tcp_auth(tcp_address, use_protocol, use_auth, use_server_ver, use_client_name): + """ + Test that various initial handshake cases are handled correctly by the client + """ + got_auth = [] + got_protocol = None + got_name = None + + def on_auth(self, auth): + got_auth[:] = auth + + def on_protocol(self, proto): + nonlocal got_protocol + got_protocol = proto + + def on_setname(self, name): + nonlocal got_name + got_name = name + + def get_server_version(self): + return use_server_ver + + if use_auth: + auth_args = {"username": "myuser", "password": "mypassword"} + else: + auth_args = {} + got_protocol = None + host, port = tcp_address + conn = Connection( + host=host, + port=port, + client_name=_CLIENT_NAME if use_client_name else None, + socket_timeout=10, + protocol=use_protocol, + **auth_args, + ) + try: + with patch.multiple( + resp.RespServer, + on_auth=on_auth, + get_server_version=get_server_version, + on_protocol=on_protocol, + on_setname=on_setname, + ): + if use_server_ver < 6 and use_protocol > 2: + with pytest.raises(ResponseError): + _assert_connect(conn, tcp_address) + return + + _assert_connect(conn, tcp_address) + if use_protocol == 3: + assert got_protocol == use_protocol + if use_auth: + if use_server_ver < 6: + assert got_auth == ["mypassword"] + else: + assert got_auth == ["myuser", "mypassword"] + + if use_client_name: + assert got_name == _CLIENT_NAME + else: + assert got_name is None + finally: + conn.disconnect() + + def _assert_connect(conn, server_address, certfile=None, keyfile=None): if isinstance(server_address, str): if not _RedisUDSServer: @@ -148,44 +232,31 @@ def finish(self): _logger.info("%s disconnected", self.client_address) def handle(self): + parser = resp.RespParser() + server = resp.RespServer() buffer = b"" - command = None - command_ptr = None - fragment_length = None - while self.server.is_serving() or buffer: - try: - buffer += self.request.recv(1024) - except socket.timeout: - continue - if not buffer: - continue - parts = re.split(_CMD_SEP, buffer) - buffer = parts[-1] - for fragment in parts[:-1]: - fragment = fragment.decode() - _logger.info("Command fragment: %s", fragment) - - if fragment.startswith("*") and command is None: - command = [None for _ in range(int(fragment[1:]))] - command_ptr = 0 - fragment_length = None - continue - - if fragment.startswith("$") and command[command_ptr] is None: - fragment_length = int(fragment[1:]) - continue - - assert len(fragment) == fragment_length - command[command_ptr] = fragment - command_ptr += 1 - - if command_ptr < len(command): + try: + # if client performs pipelining, we may need + # to adjust this code to not block when sending + # responses. + while self.server.is_serving(): + try: + command = parser.parse(buffer) + buffer = b"" + except resp.NeedMoreData: + try: + buffer = self.request.recv(1024) + except socket.timeout: + buffer = b"" + continue + if not buffer: + break # EOF continue - - command = " ".join(command) _logger.info("Command %s", command) - resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) - _logger.info("Response %s", resp) - self.request.sendall(resp) - command = None - _logger.info("Exit handler") + response = server.command(command) + _logger.info("Response %s", response) + self.request.sendall(response) + except Exception: + _logger.exception("Exception in handler") + finally: + _logger.info("Exit handler") diff --git a/tests/test_resp.py b/tests/test_resp.py new file mode 100644 index 0000000000..4706699a4f --- /dev/null +++ b/tests/test_resp.py @@ -0,0 +1,237 @@ +import pytest + +from .resp import ( + Attribute, + ErrorStr, + PushData, + VerbatimStr, + encode, + parse_all, + parse_chunks, +) + + +@pytest.fixture(params=[2, 3]) +def resp_version(request): + return request.param + + +class TestEncoder: + def test_simple_str(self): + assert encode("foo") == b"+foo\r\n" + + def test_long_str(self): + text = 3 * "fooling around with the sword in the mud" + assert len(text) == 120 + assert encode(text) == b"$120\r\n" + text.encode() + b"\r\n" + + # test strings with control characters + def test_str_with_ctrl_chars(self): + text = "foo\r\nbar" + assert encode(text) == b"$8\r\nfoo\r\nbar\r\n" + + def test_bytes(self): + assert encode(b"foo") == b"$3\r\nfoo\r\n" + + def test_int(self): + assert encode(123) == b":123\r\n" + + def test_float(self, resp_version): + data = encode(1.23, protocol=resp_version) + if resp_version == 2: + assert data == b"+1.23\r\n" + else: + assert data == b",1.23\r\n" + + def test_large_int(self, resp_version): + data = encode(2**63, protocol=resp_version) + if resp_version == 2: + assert data == b"+9223372036854775808\r\n" + else: + assert data == b"(9223372036854775808\r\n" + + def test_array(self): + assert encode([1, 2, 3]) == b"*3\r\n:1\r\n:2\r\n:3\r\n" + + def test_push_data(self, resp_version): + data = encode(PushData([1, 2, 3]), protocol=resp_version) + if resp_version == 2: + assert data == b"*3\r\n:1\r\n:2\r\n:3\r\n" + else: + assert data == b">3\r\n:1\r\n:2\r\n:3\r\n" + + def test_set(self, resp_version): + data = encode({1, 2, 3}, protocol=resp_version) + if resp_version == 2: + assert data == b"*3\r\n:1\r\n:2\r\n:3\r\n" + else: + assert data == b"~3\r\n:1\r\n:2\r\n:3\r\n" + + def test_map(self, resp_version): + data = encode({1: 2, 3: 4}, protocol=resp_version) + if resp_version == 2: + assert data == b"*4\r\n:1\r\n:2\r\n:3\r\n:4\r\n" + else: + assert data == b"%2\r\n:1\r\n:2\r\n:3\r\n:4\r\n" + + def test_attribute(self, resp_version): + data = encode(Attribute({1: 2, 3: 4}), protocol=resp_version) + if resp_version == 2: + assert data == b"*4\r\n:1\r\n:2\r\n:3\r\n:4\r\n" + else: + assert data == b"|2\r\n:1\r\n:2\r\n:3\r\n:4\r\n" + + def test_nested_array(self): + assert encode([1, [2, 3]]) == b"*2\r\n:1\r\n*2\r\n:2\r\n:3\r\n" + + def test_nested_map(self, resp_version): + data = encode({1: {2: 3}}, protocol=resp_version) + if resp_version == 2: + assert data == b"*2\r\n:1\r\n*2\r\n:2\r\n:3\r\n" + else: + assert data == b"%1\r\n:1\r\n%1\r\n:2\r\n:3\r\n" + + def test_null(self, resp_version): + data = encode(None, protocol=resp_version) + if resp_version == 2: + assert data == b"$-1\r\n" + else: + assert data == b"_\r\n" + + def test_mixed_array(self, resp_version): + data = encode([1, "foo", 2.3, None, True], protocol=resp_version) + if resp_version == 2: + assert data == b"*5\r\n:1\r\n+foo\r\n+2.3\r\n$-1\r\n:1\r\n" + else: + assert data == b"*5\r\n:1\r\n+foo\r\n,2.3\r\n_\r\nt\r\n" + + def test_bool(self, resp_version): + data = encode(True, protocol=resp_version) + if resp_version == 2: + assert data == b":1\r\n" + else: + assert data == b"t\r\n" + + data = encode(False, resp_version) + if resp_version == 2: + assert data == b":0\r\n" + else: + assert data == b"f\r\n" + + def test_errorstr(self, resp_version): + err = ErrorStr("foo", "bar\r\nbaz") + data = encode(err, protocol=resp_version) + if resp_version == 2: + assert data == b"-FOO bar\\r\\nbaz\r\n" + else: + assert data == b"!12\r\nFOO bar\r\nbaz\r\n" + + +@pytest.mark.parametrize("chunk_size", [0, 1, 2, -2]) +class TestParser: + def breakup_bytes(self, data, chunk_size=2): + insert_empty = False + if chunk_size < 0: + insert_empty = True + chunk_size = -chunk_size + chunks = [data[i : i + chunk_size] for i in range(0, len(data), chunk_size)] + if insert_empty: + empty = len(chunks) * [b""] + chunks = [item for pair in zip(chunks, empty) for item in pair] + return chunks + + def parse_data(self, chunk_size, data): + """helper to parse either a single blob, or a list of chunks""" + if chunk_size == 0: + return parse_all(data) + else: + return parse_chunks(self.breakup_bytes(data, chunk_size)) + + def test_int(self, chunk_size): + parsed = self.parse_data(chunk_size, b":123\r\n") + assert parsed == ([123], b"") + + parsed = self.parse_data(chunk_size, b":123\r\nfoo") + assert parsed == ([123], b"foo") + + def test_double(self, chunk_size): + parsed = self.parse_data(chunk_size, b",1.23\r\njunk") + assert parsed == ([1.23], b"junk") + + def test_array(self, chunk_size): + parsed = self.parse_data(chunk_size, b"*3\r\n:1\r\n:2\r\n:3\r\n") + assert parsed == ([[1, 2, 3]], b"") + + parsed = self.parse_data(chunk_size, b"*3\r\n:1\r\n:2\r\n:3\r\nfoo") + assert parsed == ([[1, 2, 3]], b"foo") + + def test_push_data(self, chunk_size): + parsed = self.parse_data(chunk_size, b">3\r\n:1\r\n:2\r\n:3\r\n") + assert isinstance(parsed[0][0], PushData) + assert parsed == ([[1, 2, 3]], b"") + + def test_incomplete_list(self, chunk_size): + parsed = self.parse_data(chunk_size, b"*3\r\n:1\r\n:2\r\n") + assert parsed == ([], b"*3\r\n:1\r\n:2\r\n") + + def test_invalid_token(self, chunk_size): + with pytest.raises(ValueError): + self.parse_data(chunk_size, b")foo\r\n") + with pytest.raises(ValueError): + self.parse_data(chunk_size, b"!foo\r\n") + + def test_multiple_ints(self, chunk_size): + parsed = self.parse_data(chunk_size, b":1\r\n:2\r\n:3\r\n") + assert parsed == ([1, 2, 3], b"") + + def test_multiple_ints_and_junk(self, chunk_size): + parsed = self.parse_data(chunk_size, b":1\r\n:2\r\n:3\r\n*3\r\n:1\r\n:2\r\n") + assert parsed == ([1, 2, 3], b"*3\r\n:1\r\n:2\r\n") + + def test_set(self, chunk_size): + parsed = self.parse_data(chunk_size, b"~3\r\n:1\r\n:2\r\n:3\r\n") + assert parsed == ([{1, 2, 3}], b"") + + def test_list_of_sets(self, chunk_size): + parsed = self.parse_data( + chunk_size, b"*2\r\n~3\r\n:1\r\n:2\r\n:3\r\n~2\r\n:4\r\n:5\r\n" + ) + assert parsed == ([[{1, 2, 3}, {4, 5}]], b"") + + def test_map(self, chunk_size): + parsed = self.parse_data(chunk_size, b"%2\r\n:1\r\n:2\r\n:3\r\n:4\r\n") + assert parsed == ([{1: 2, 3: 4}], b"") + + def test_simple_string(self, chunk_size): + parsed = self.parse_data(chunk_size, b"+foo\r\n") + assert parsed == (["foo"], b"") + + def test_bulk_string(self, chunk_size): + parsed = parse_all(b"$3\r\nfoo\r\nbar") + assert parsed == (["foo"], b"bar") + + def test_bulk_string_with_ctrl_chars(self, chunk_size): + parsed = self.parse_data(chunk_size, b"$8\r\nfoo\r\nbar\r\n") + assert parsed == (["foo\r\nbar"], b"") + + def test_verbatimstr(self, chunk_size): + parsed = self.parse_data(chunk_size, b"=3\r\ntxt:foo\r\nbar") + assert parsed == ([VerbatimStr("foo", "txt")], b"bar") + + def test_errorstr(self, chunk_size): + parsed = self.parse_data(chunk_size, b"-FOO bar\r\nbaz") + assert parsed == ([ErrorStr("foo", "bar")], b"baz") + + def test_errorstr_resp3(self, chunk_size): + parsed = self.parse_data(chunk_size, b"!12\r\nFOO bar\r\nbaz\r\n") + assert parsed == ([ErrorStr("foo", "bar\r\nbaz")], b"") + + def test_attribute_map(self, chunk_size): + parsed = self.parse_data(chunk_size, b"|2\r\n:1\r\n:2\r\n:3\r\n:4\r\n") + assert parsed == ([Attribute({1: 2, 3: 4})], b"") + + def test_surrogateescape(self, chunk_size): + data = b"foo\xff" + parsed = self.parse_data(chunk_size, b"$4\r\n" + data + b"\r\nbar") + assert parsed == ([data.decode(errors="surrogateescape")], b"bar") + assert parsed[0][0].encode("utf-8", "surrogateescape") == data