From 50123cc7ea2ff61b383b61359e0a3b203213df25 Mon Sep 17 00:00:00 2001 From: woutdenolf Date: Wed, 22 Mar 2023 14:51:56 +0100 Subject: [PATCH] parse Redis commands in the mock server and shutdown server on failure --- tests/test_connect.py | 188 +++++++++++++++++++++++++++--------------- 1 file changed, 123 insertions(+), 65 deletions(-) diff --git a/tests/test_connect.py b/tests/test_connect.py index fd7b24fc7e..843665770b 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -1,4 +1,5 @@ import logging +import re import socket import ssl import threading @@ -12,6 +13,43 @@ _CLIENT_NAME = "test-suite-client" +_CMD_SEP = b"\r\n" +_SUCCESS_RESP = b"+OK" + _CMD_SEP +_ERROR_RESP = b"-ERR" + _CMD_SEP +_COMMANDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} + + +@pytest.fixture +def tcp_address(): + with socket.socket() as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname() + + +@pytest.fixture +def uds_address(tmpdir): + return tmpdir / "uds.sock" + + +@pytest.fixture +def ssl_cert(tcp_address, tmpdir): + """More or less equivalent to + + .. code:: + + openssl req -new -x509 -days 365 -nodes -out mycert.pem -keyout mycert.pem + """ + host, _ = tcp_address + ca = trustme.CA() + cert = ca.issue_cert(host, common_name="trustme") + + server_pem = str(tmpdir / "server.pem") + cert.private_key_and_cert_chain_pem.write_to_path(path=server_pem) + + client_pem = str(tmpdir / "client.pem") + ca.cert_pem.write_to_path(path=client_pem) + + return client_pem, server_pem def test_tcp_connect(tcp_address): @@ -35,7 +73,25 @@ def test_tcp_ssl_connect(tcp_address, ssl_cert): _assert_connect(conn, tcp_address, certfile=server_pem) -def redis_mock_server(server_address, ready, commands, certfile=None): +def _assert_connect(conn, server_address, certfile=None): + ready = threading.Event() + stop = threading.Event() + t = threading.Thread( + target=_redis_mock_server, + args=(server_address, ready, stop), + kwargs={"certfile": certfile}, + ) + t.start() + try: + ready.wait() + conn.connect() + conn.disconnect() + finally: + stop.set() + t.join(timeout=5) + + +def _redis_mock_server(server_address, ready, stop, certfile=None): try: if isinstance(server_address, str): family = socket.AF_UNIX @@ -46,86 +102,88 @@ def redis_mock_server(server_address, ready, commands, certfile=None): else: family = socket.AF_INET mockname = "Redis mock server (TCP)" + with socket.socket(family, socket.SOCK_STREAM) as s: s.bind(server_address) s.listen(1) + s.settimeout(0.1) if certfile: context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + context.minimum_version = ssl.TLSVersion.TLSv1_2 context.load_cert_chain(certfile=certfile) _logger.info("Start %s: %s", mockname, server_address) ready.set() - ssock, _ = s.accept() - with ssock: + + # Wait a client connection + while not stop.is_set(): + try: + sconn, _ = s.accept() + sconn.settimeout(0.1) + break + except socket.timeout: + pass + if stop.is_set(): + _logger.info("Exit %s: %s", mockname, server_address) + return + + # Receive commands from the client + with sconn: if certfile: - conn = context.wrap_socket(ssock, server_side=True) + conn = context.wrap_socket(sconn, server_side=True) else: - conn = ssock + conn = sconn try: - while True: - data = conn.recv(1024) - if not data: - _logger.info("Exit %s: %s", mockname, server_address) - break - _logger.info("Command in %s: %s", mockname, data) - resp = b"+ERROR\r\n" - resp = commands.get(data, resp) - _logger.info("Response from %s: %s", mockname, resp) - conn.sendall(resp) + buffer = b"" + command = None + command_ptr = None + fragment_length = None + while not stop.is_set() or buffer: + try: + buffer += conn.recv(1024) + except socket.timeout: + continue + if not buffer: + continue + parts = re.split(_CMD_SEP, buffer) + buffer = parts[-1] + for fragment in parts[:-1]: + fragment = fragment.decode() + _logger.info( + "Command fragment in %s: %s", mockname, fragment + ) + + if fragment.startswith("*") and command is None: + command = [None for _ in range(int(fragment[1:]))] + command_ptr = 0 + fragment_length = None + continue + + if ( + fragment.startswith("$") + and command[command_ptr] is None + ): + fragment_length = int(fragment[1:]) + continue + + assert len(fragment) == fragment_length + command[command_ptr] = fragment + command_ptr += 1 + + if command_ptr < len(command): + continue + + command = " ".join(command) + _logger.info("Command in %s: %s", mockname, command) + resp = _COMMANDS.get(command, _ERROR_RESP) + _logger.info("Response from %s: %s", mockname, resp) + conn.sendall(resp) + command = None finally: if certfile: conn.close() + _logger.info("Exit %s: %s", mockname, server_address) except BaseException as e: _logger.exception("Error in %s: %s", mockname, e) raise - - -def _assert_connect(conn, server_address, **server_kwargs): - command = conn.pack_command("CLIENT", "SETNAME", _CLIENT_NAME)[0] - commands = {command: b"+OK\r\n"} - - ready = threading.Event() - t = threading.Thread( - target=redis_mock_server, - args=(server_address, ready, commands), - kwargs=server_kwargs, - ) - t.start() - ready.wait() - conn.connect() - conn.disconnect() - t.join() - - -@pytest.fixture -def tcp_address(): - with socket.socket() as sock: - sock.bind(("127.0.0.1", 0)) - return sock.getsockname() - - -@pytest.fixture -def uds_address(tmpdir): - return tmpdir / "uds.sock" - - -@pytest.fixture -def ssl_cert(tcp_address, tmpdir): - """More or less equivalent to - - .. code:: - - openssl req -new -x509 -days 365 -nodes -out mycert.pem -keyout mycert.pem - """ - host, _ = tcp_address - ca = trustme.CA() - cert = ca.issue_cert(host, common_name="trustme") - - server_pem = str(tmpdir / "server.pem") - cert.private_key_and_cert_chain_pem.write_to_path(path=server_pem) - - client_pem = str(tmpdir / "client.pem") - ca.cert_pem.write_to_path(path=client_pem) - - return client_pem, server_pem