From dcb43be8916865e066643a3b8fd1bd6fb767601a Mon Sep 17 00:00:00 2001 From: awmackowiak Date: Wed, 12 Jun 2024 18:00:22 +0200 Subject: [PATCH] Fix Redis connections after reconnect - consumer starts consuming the tasks after crash. (#2007) * Add more logs * Launch _on_connection_disconnect in Conection only if channel was added properly to the poller * Prepare test which check the flow of the channel removal from poller * Change the comment --- kombu/transport/redis.py | 10 ++- t/unit/transport/test_redis.py | 156 +++++++++++++++++++++++++++++++++ 2 files changed, 164 insertions(+), 2 deletions(-) diff --git a/kombu/transport/redis.py b/kombu/transport/redis.py index 89cabda68..9311ecf5c 100644 --- a/kombu/transport/redis.py +++ b/kombu/transport/redis.py @@ -722,7 +722,7 @@ def __init__(self, *args, **kwargs): if not self.ack_emulation: # disable visibility timeout self.QoS = virtual.QoS - + self._registered = False self._queue_cycle = cycle_by_name(self.queue_order_strategy)() self.Client = self._get_client() self.ResponseError = self._get_response_error() @@ -747,6 +747,9 @@ def __init__(self, *args, **kwargs): raise self.connection.cycle.add(self) # add to channel poller. + # and set to true after sucessfuly added channel to the poll. + self._registered = True + # copy errors, in case channel closed but threads still # are still waiting for data. self.connection_errors = self.connection.connection_errors @@ -1201,7 +1204,10 @@ def _connparams(self, asynchronous=False): class Connection(connection_cls): def disconnect(self, *args): super().disconnect(*args) - channel._on_connection_disconnect(self) + # We remove the connection from the poller + # only if it has been added properly. + if channel._registered: + channel._on_connection_disconnect(self) connection_cls = Connection connparams['connection_class'] = connection_cls diff --git a/t/unit/transport/test_redis.py b/t/unit/transport/test_redis.py index 1bc81e0e2..a2c015ec2 100644 --- a/t/unit/transport/test_redis.py +++ b/t/unit/transport/test_redis.py @@ -346,17 +346,173 @@ class XTransport(Transport): Channel = XChannel conn = Connection(transport=XTransport) + conn.transport.cycle = Mock(name='cycle') client.ping.side_effect = RuntimeError() with pytest.raises(RuntimeError): conn.channel() pool.disconnect.assert_called_with() pool.disconnect.reset_mock() + # Ensure that the channel without ensured connection to Redis + # won't be added to the cycle. + conn.transport.cycle.add.assert_not_called() + assert len(conn.transport.channels) == 0 pool_at_init = [None] with pytest.raises(RuntimeError): conn.channel() pool.disconnect.assert_not_called() + def test_redis_connection_added_to_cycle_if_ping_succeeds(self): + """Test should check the connection is added to the cycle only + if the ping to Redis was finished successfully.""" + # given: mock pool and client + pool = Mock(name='pool') + client = Mock(name='client') + + # override channel class with given mocks + class XChannel(Channel): + def __init__(self, *args, **kwargs): + self._pool = pool + super().__init__(*args, **kwargs) + + def _get_client(self): + return lambda *_, **__: client + + # override Channel in Transport with given channel + class XTransport(Transport): + Channel = XChannel + + # when: create connection with overridden transport + conn = Connection(transport=XTransport) + conn.transport.cycle = Mock(name='cycle') + # create the channel + chan = conn.channel() + # then: check if ping was called + client.ping.assert_called_once() + # the connection was added to the cycle + conn.transport.cycle.add.assert_called_once() + assert len(conn.transport.channels) == 1 + # the channel was flaged as registered into poller + assert chan._registered + + def test_redis_on_disconnect_channel_only_if_was_registered(self): + """Test shoud check if the _on_disconnect method is called only + if the channel was registered into the poller.""" + # given: mock pool and client + pool = Mock(name='pool') + client = Mock( + name='client', + ping=Mock(return_value=True) + ) + + # create RedisConnectionMock class + # for the possibility to run disconnect method + class RedisConnectionMock: + def disconnect(self, *args): + pass + + # override Channel method with given mocks + class XChannel(Channel): + connection_class = RedisConnectionMock + + def __init__(self, *args, **kwargs): + self._pool = pool + # counter to check if the method was called + self.on_disconect_count = 0 + super().__init__(*args, **kwargs) + + def _get_client(self): + return lambda *_, **__: client + + def _on_connection_disconnect(self, connection): + # increment the counter when the method is called + self.on_disconect_count += 1 + + # create the channel + chan = XChannel(Mock( + _used_channel_ids=[], + channel_max=1, + channels=[], + client=Mock( + transport_options={}, + hostname="127.0.0.1", + virtual_host=None))) + # create the _connparams with overriden connection_class + connparams = chan._connparams(asynchronous=True) + # create redis.Connection + redis_connection = connparams['connection_class']() + # the connection was added to the cycle + chan.connection.cycle.add.assert_called_once() + # and the ping was called + client.ping.assert_called_once() + # the channel was registered + assert chan._registered + # than disconnect the Redis connection + redis_connection.disconnect() + # the on_disconnect counter should be incremented + assert chan.on_disconect_count == 1 + + def test_redis__on_disconnect_should_not_be_called_if_not_registered(self): + """Test should check if the _on_disconnect method is not called because + the connection to Redis isn't established properly.""" + # given: mock pool + pool = Mock(name='pool') + # client mock with ping method which return ConnectionError + from redis.exceptions import ConnectionError + client = Mock( + name='client', + ping=Mock(side_effect=ConnectionError()) + ) + + # create RedisConnectionMock + # for the possibility to run disconnect method + class RedisConnectionMock: + def disconnect(self, *args): + pass + + # override Channel method with given mocks + class XChannel(Channel): + connection_class = RedisConnectionMock + + def __init__(self, *args, **kwargs): + self._pool = pool + # counter to check if the method was called + self.on_disconect_count = 0 + super().__init__(*args, **kwargs) + + def _get_client(self): + return lambda *_, **__: client + + def _on_connection_disconnect(self, connection): + # increment the counter when the method is called + self.on_disconect_count += 1 + + # then: exception was risen + with pytest.raises(ConnectionError): + # when: create the channel + chan = XChannel(Mock( + _used_channel_ids=[], + channel_max=1, + channels=[], + client=Mock( + transport_options={}, + hostname="127.0.0.1", + virtual_host=None))) + # create the _connparams with overriden connection_class + connparams = chan._connparams(asynchronous=True) + # create redis.Connection + redis_connection = connparams['connection_class']() + # the connection wasn't added to the cycle + chan.connection.cycle.add.assert_not_called() + # the ping was called once with the exception + client.ping.assert_called_once() + # the channel was not registered + assert not chan._registered + # then: disconnect the Redis connection + redis_connection.disconnect() + # the on_disconnect counter shouldn't be incremented + assert chan.on_disconect_count == 0 + def test_get_redis_ConnectionError(self): from redis.exceptions import ConnectionError