From 3ea72465ba4ad32df6eec286179813396b46a5f5 Mon Sep 17 00:00:00 2001 From: woutdenolf Date: Wed, 22 Mar 2023 17:01:04 +0100 Subject: [PATCH] tests: add 'connect' tests for all Redis connection classes --- tests/test_connect.py | 166 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 tests/test_connect.py diff --git a/tests/test_connect.py b/tests/test_connect.py new file mode 100644 index 0000000000..9f11872201 --- /dev/null +++ b/tests/test_connect.py @@ -0,0 +1,166 @@ +import logging +import re +import socket +import ssl +import threading + +import pytest + +from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection + +from .ssl_utils import get_ssl_filename + +_logger = logging.getLogger(__name__) + + +_CLIENT_NAME = "test-suite-client" +_CMD_SEP = b"\r\n" +_SUCCESS_RESP = b"+OK" + _CMD_SEP +_ERROR_RESP = b"-ERR" + _CMD_SEP +_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} + + +@pytest.fixture +def tcp_address(): + with socket.socket() as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname() + + +@pytest.fixture +def uds_address(tmpdir): + return tmpdir / "uds.sock" + + +def test_tcp_connect(tcp_address): + host, port = tcp_address + conn = Connection(host=host, port=port, client_name=_CLIENT_NAME) + _assert_connect(conn, tcp_address) + + +def test_uds_connect(uds_address): + path = str(uds_address) + conn = UnixDomainSocketConnection(path, client_name=_CLIENT_NAME) + _assert_connect(conn, path) + + +@pytest.mark.ssl +def test_tcp_ssl_connect(tcp_address): + host, port = tcp_address + certfile = get_ssl_filename("server-cert.pem") + keyfile = get_ssl_filename("server-key.pem") + conn = SSLConnection( + host=host, port=port, client_name=_CLIENT_NAME, ssl_ca_certs=certfile + ) + _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) + + +def _assert_connect(conn, server_address, certfile=None, keyfile=None): + ready = threading.Event() + stop = threading.Event() + t = threading.Thread( + target=_redis_mock_server, + args=(server_address, ready, stop), + kwargs={"certfile": certfile, "keyfile": keyfile}, + ) + t.start() + try: + ready.wait() + conn.connect() + conn.disconnect() + finally: + stop.set() + t.join(timeout=5) + + +def _redis_mock_server(server_address, ready, stop, certfile=None, keyfile=None): + try: + if isinstance(server_address, str): + family = socket.AF_UNIX + mockname = "Redis mock server (UDS)" + elif certfile: + family = socket.AF_INET + mockname = "Redis mock server (TCP-SSL)" + else: + family = socket.AF_INET + mockname = "Redis mock server (TCP)" + + with socket.socket(family, socket.SOCK_STREAM) as s: + s.bind(server_address) + s.listen(1) + s.settimeout(0.1) + + if certfile: + context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + context.minimum_version = ssl.TLSVersion.TLSv1_2 + context.load_cert_chain(certfile=certfile, keyfile=keyfile) + + _logger.info("Start %s: %s", mockname, server_address) + ready.set() + + # Wait for a client connection + while not stop.is_set(): + try: + sconn, _ = s.accept() + sconn.settimeout(0.1) + break + except socket.timeout: + pass + if stop.is_set(): + _logger.info("Exit %s: %s", mockname, server_address) + return + + # Handle commands from the client + with sconn: + if certfile: + with context.wrap_socket(sconn, server_side=True) as wconn: + _redis_mock_server_handle(wconn, stop, mockname) + else: + _redis_mock_server_handle(sconn, stop, mockname) + _logger.info("Exit %s: %s", mockname, server_address) + except BaseException as e: + _logger.exception("Error in %s: %s", mockname, e) + raise + + +def _redis_mock_server_handle(conn, stop, mockname): + buffer = b"" + command = None + command_ptr = None + fragment_length = None + while not stop.is_set() or buffer: + try: + buffer += conn.recv(1024) + except socket.timeout: + continue + if not buffer: + continue + parts = re.split(_CMD_SEP, buffer) + buffer = parts[-1] + for fragment in parts[:-1]: + fragment = fragment.decode() + _logger.info("Command fragment in %s: %s", mockname, fragment) + + if fragment.startswith("*") and command is None: + command = [None for _ in range(int(fragment[1:]))] + command_ptr = 0 + fragment_length = None + continue + + if fragment.startswith("$") and command[command_ptr] is None: + fragment_length = int(fragment[1:]) + continue + + assert len(fragment) == fragment_length + command[command_ptr] = fragment + command_ptr += 1 + + if command_ptr < len(command): + continue + + command = " ".join(command) + _logger.info("Command in %s: %s", mockname, command) + resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) + _logger.info("Response from %s: %s", mockname, resp) + conn.sendall(resp) + command = None