From a1d5a9bb1df84fd0eed1a1397bcf8c4869eb84b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Mon, 12 Dec 2022 22:34:35 +0000 Subject: [PATCH] Add regression tests and fixes for issue #1128 --- redis/asyncio/client.py | 4 ++- redis/asyncio/connection.py | 28 +++++++++++++-------- redis/client.py | 2 +- redis/connection.py | 24 +++++++++++++----- tests/test_asyncio/test_commands.py | 38 +++++++++++++++++++++++++++++ tests/test_commands.py | 35 ++++++++++++++++++++++++++ 6 files changed, 113 insertions(+), 18 deletions(-) diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 3e6626aedf..cb0dcb1731 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -816,7 +816,9 @@ async def parse_response(self, block: bool = True, timeout: float = 0): await conn.connect() read_timeout = None if block else timeout - response = await self._execute(conn, conn.read_response, timeout=read_timeout) + response = await self._execute( + conn, conn.read_response, timeout=read_timeout, disconnect_on_error=False + ) if conn.health_check_interval and response == self.health_check_response: # ignore the health check message as user might not expect it diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 59f75aa229..f430d0a1c6 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -804,7 +804,11 @@ async def send_packed_command( raise ConnectionError( f"Error {err_no} while writing to socket. {errmsg}." ) from e - except Exception: + except BaseException: + # BaseExceptions can be raised when a socket send operation is not + # finished, e.g. due to a timeout. Ideally, a caller could then re-try + # to send un-sent data. However, the send_packed_command() API + # does not support it so there is no point in keeping the connection open. 
await self.disconnect(nowait=True) raise @@ -827,7 +831,9 @@ async def can_read_destructive(self): async def read_response( self, disable_decoding: bool = False, + *, timeout: Optional[float] = None, + disconnect_on_error: bool = True, ): """Read the response from a previously sent command""" read_timeout = timeout if timeout is not None else self.socket_timeout @@ -843,22 +849,24 @@ async def read_response( ) except asyncio.TimeoutError: if timeout is not None: - # user requested timeout, return None + # user requested timeout, return None. Operation can be retried return None # it was a self.socket_timeout error. - await self.disconnect(nowait=True) + if disconnect_on_error: + await self.disconnect(nowait=True) raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") except OSError as e: - await self.disconnect(nowait=True) + if disconnect_on_error: + await self.disconnect(nowait=True) raise ConnectionError( f"Error while reading from {self.host}:{self.port} : {e.args}" ) - except asyncio.CancelledError: - # need this check for 3.7, where CancelledError - # is subclass of Exception, not BaseException - raise - except Exception: - await self.disconnect(nowait=True) + except BaseException: + # Also by default close in case of BaseException. A lot of code + # relies on this behaviour when doing Command/Response pairs. + # See #1128. 
+ if disconnect_on_error: + await self.disconnect(nowait=True) raise if self.health_check_interval: diff --git a/redis/client.py b/redis/client.py index 79a7bff2a2..ea27dea715 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1529,7 +1529,7 @@ def try_read(): return None else: conn.connect() - return conn.read_response() + return conn.read_response(disconnect_on_error=False) response = self._execute(conn, try_read) diff --git a/redis/connection.py b/redis/connection.py index 8b2389c6db..5af8928a5d 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -834,7 +834,11 @@ def send_packed_command(self, command, check_health=True): errno = e.args[0] errmsg = e.args[1] raise ConnectionError(f"Error {errno} while writing to socket. {errmsg}.") - except Exception: + except BaseException: + # BaseExceptions can be raised when a socket send operation is not + # finished, e.g. due to a timeout. Ideally, a caller could then re-try + # to send un-sent data. However, the send_packed_command() API + # does not support it so there is no point in keeping the connection open. 
self.disconnect() raise @@ -859,7 +863,9 @@ def can_read(self, timeout=0): self.disconnect() raise ConnectionError(f"Error while reading from {host_error}: {e.args}") - def read_response(self, disable_decoding=False): + def read_response( + self, disable_decoding=False, *, disconnect_on_error: bool = True + ): """Read the response from a previously sent command""" host_error = self._host_error() @@ -867,15 +873,21 @@ def read_response(self, disable_decoding=False): try: response = self._parser.read_response(disable_decoding=disable_decoding) except socket.timeout: - self.disconnect() + if disconnect_on_error: + self.disconnect() raise TimeoutError(f"Timeout reading from {host_error}") except OSError as e: - self.disconnect() + if disconnect_on_error: + self.disconnect() raise ConnectionError( f"Error while reading from {host_error}" f" : {e.args}" ) - except Exception: - self.disconnect() + except BaseException: + # Also by default close in case of BaseException. A lot of code + # relies on this behaviour when doing Command/Response pairs. + # See #1128. 
+ if disconnect_on_error: + self.disconnect() raise if self.health_check_interval: diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 409934c9a3..a7df03ad91 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -1,9 +1,11 @@ """ Tests async overrides of commands from their mixins """ +import asyncio import binascii import datetime import re +import sys from string import ascii_letters import pytest @@ -18,6 +20,11 @@ skip_unless_arch_bits, ) +if sys.version_info.major >= 3 and sys.version_info.minor >= 11: + from asyncio import timeout as async_timeout +else: + from async_timeout import timeout as async_timeout + REDIS_6_VERSION = "5.9.0" @@ -3008,6 +3015,37 @@ async def test_module_list(self, r: redis.Redis): for x in await r.module_list(): assert isinstance(x, dict) + @pytest.mark.onlynoncluster + async def test_interrupted_command(self, r: redis.Redis): + """ + Regression test for issue #1128: An Un-handled BaseException + will leave the socket with un-read response to a previous + command. + """ + ready = asyncio.Event() + + async def helper(): + with pytest.raises(asyncio.CancelledError): + # blocking pop + ready.set() + await r.brpop(["nonexist"]) + # If the following is not done, further Timeout operations will fail, + # because the timeout won't catch its Cancelled Error if the task + # has a pending cancel. Python documentation probably should reflect this. + if sys.version_info.major >= 3 and sys.version_info.minor >= 11: + asyncio.current_task().uncancel() + # if all is well, we can continue. The following should not hang. 
+ await r.set("status", "down") + + task = asyncio.create_task(helper()) + await ready.wait() + await asyncio.sleep(0.01) + # the task is now sleeping, let's send it an exception + task.cancel() + # If all is well, the task should finish right away, otherwise fail with Timeout + async with async_timeout(0.1): + await task + @pytest.mark.onlynoncluster class TestBinarySave: diff --git a/tests/test_commands.py b/tests/test_commands.py index 2b769be34d..f5e88bd095 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -1,9 +1,12 @@ import binascii import datetime import re +import threading import time +from asyncio import CancelledError from string import ascii_letters from unittest import mock +from unittest.mock import patch import pytest @@ -4741,6 +4744,38 @@ def test_psync(self, r): res = r2.psync(r2.client_id(), 1) assert b"FULLRESYNC" in res + @pytest.mark.onlynoncluster + def test_interrupted_command(self, r: redis.Redis): + """ + Regression test for issue #1128: An Un-handled BaseException + will leave the socket with un-read response to a previous + command. + """ + + ok = False + + def helper(): + with pytest.raises(CancelledError): + # blocking pop + with patch.object( + r.connection._parser, "read_response", side_effect=CancelledError + ): + r.brpop(["nonexist"]) + # if all is well, we can continue. + r.set("status", "down") # should not hang + nonlocal ok + ok = True + + thread = threading.Thread(target=helper) + thread.start() + thread.join(0.1) + try: + assert not thread.is_alive() + assert ok + finally: + # disconnect here so that fixture cleanup can proceed + r.connection.disconnect() + @pytest.mark.onlynoncluster class TestBinarySave: