Skip to content

Commit

Permalink
Support Select Connections in emit_mq_message (#115)
Browse files Browse the repository at this point in the history
* Update `emit_mq_message` to support SelectConnections
Related to neon-iris improvements

* Update `emit_mq_message` to support SelectConnection channel creation callback

* Loosen dependency to allow newer `pika`

* Troubleshoot channel opening

* Add more logging

* Add more logging to troubleshoot blocking connection usage

* Cleanup logging
Troubleshoot blocking connection usage

* Update `emit_mq_message` to prevent mutating input data
Add test coverage for `emit_mq_message` for Blocking and Select connections

* Update log to DEBUG
Disable sync and observer threads in test Connector instance

* Update test to wait for open event
  • Loading branch information
NeonDaniel authored Jan 23, 2025
1 parent a9fba67 commit cfc073f
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 30 deletions.
71 changes: 43 additions & 28 deletions neon_mq_connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import copy
import time
import uuid
from asyncio import Event

import pika
import pika.exceptions

Expand All @@ -46,7 +48,6 @@
from neon_mq_connector.utils.network_utils import dict_to_b64
from neon_mq_connector.utils.thread_utils import RepeatingTimer


# DO NOT REMOVE ME: Defined for backward compatibility
ConsumerThread = BlockingConsumerThread

Expand Down Expand Up @@ -271,7 +272,8 @@ def create_unique_id():

@classmethod
def emit_mq_message(cls,
connection: pika.BlockingConnection,
connection: Union[pika.BlockingConnection,
pika.SelectConnection],
request_data: dict,
exchange: Optional[str] = '',
queue: Optional[str] = '',
Expand All @@ -292,6 +294,9 @@ def emit_mq_message(cls,
:raises ValueError: invalid request data provided
:returns message_id: id of the sent message
"""
# Make a copy of request_data to prevent modifying the input object
request_data = dict(request_data)

if not isinstance(request_data, dict):
raise TypeError(f"Expected dict and got {type(request_data)}")
if not request_data:
Expand All @@ -302,22 +307,32 @@ def emit_mq_message(cls,
.get("mq", {}).get("message_id") or
cls.create_unique_id())

with connection.channel() as channel:
def _on_channel_open(new_channel):
if exchange:
channel.exchange_declare(exchange=exchange,
exchange_type=exchange_type,
auto_delete=False)
new_channel.exchange_declare(exchange=exchange,
exchange_type=exchange_type,
auto_delete=False)
if queue:
declared_queue = channel.queue_declare(queue=queue,
auto_delete=False)
declared_queue = new_channel.queue_declare(queue=queue,
auto_delete=False)
if exchange_type == ExchangeType.fanout.value:
channel.queue_bind(queue=declared_queue.method.queue,
exchange=exchange)
channel.basic_publish(exchange=exchange or '',
routing_key=queue,
body=dict_to_b64(request_data),
properties=pika.BasicProperties(
expiration=str(expiration)))
new_channel.queue_bind(queue=declared_queue.method.queue,
exchange=exchange)
new_channel.basic_publish(exchange=exchange or '',
routing_key=queue,
body=dict_to_b64(request_data),
properties=pika.BasicProperties(
expiration=str(expiration)))

new_channel.close()

if isinstance(connection, pika.BlockingConnection):
LOG.debug(f"Using blocking connection for request: {request_data}")
_on_channel_open(connection.channel())
else:
LOG.debug(f"Using select connection for queue: {queue}")
connection.channel(on_open_callback=_on_channel_open)

LOG.debug(f"sent message: {request_data['message_id']}")
return request_data['message_id']

Expand Down Expand Up @@ -448,17 +463,17 @@ def register_consumer(self, name: str, vhost: str, queue: str,
self.consumer_properties.setdefault(name, {})
self.consumer_properties[name]['properties'] = \
dict(
name=name,
connection_params=self.get_connection_params(vhost),
queue=queue,
queue_reset=queue_reset,
callback_func=callback,
exchange=exchange,
exchange_reset=exchange_reset,
exchange_type=exchange_type,
error_func=error_handler,
auto_ack=auto_ack,
queue_exclusive=queue_exclusive,
name=name,
connection_params=self.get_connection_params(vhost),
queue=queue,
queue_reset=queue_reset,
callback_func=callback,
exchange=exchange,
exchange_reset=exchange_reset,
exchange_type=exchange_type,
error_func=error_handler,
auto_ack=auto_ack,
queue_exclusive=queue_exclusive,
)
self.consumer_properties[name]['restart_attempts'] = int(restart_attempts)
self.consumer_properties[name]['started'] = False
Expand Down Expand Up @@ -556,8 +571,8 @@ def run_consumers(self, names: Optional[tuple] = None, daemon=True):
names = list(self.consumers)
for name in names:
if (isinstance(self.consumers.get(name), SUPPORTED_THREADED_CONSUMERS)
and self.consumers[name].is_consumer_alive
and not self.consumers[name].is_consuming):
and self.consumers[name].is_consumer_alive
and not self.consumers[name].is_consuming):
self.consumers[name].daemon = daemon
self.consumers[name].start()
self.consumer_properties[name]['started'] = True
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
pika==1.2.0
pika~=1.2
ovos-config~=0.0,>=0.0.8
ovos-utils~=0.0,>=0.0.32
82 changes: 81 additions & 1 deletion tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@

from unittest.mock import Mock, patch
from ovos_utils.log import LOG
from pika.adapters.blocking_connection import BlockingConnection
from pika.adapters.select_connection import SelectConnection
from pika.exchange_type import ExchangeType

from neon_mq_connector.connector import MQConnector, ConsumerThreadInstance
Expand Down Expand Up @@ -384,4 +386,82 @@ def test_init_rmq_down(self, get_timeout):
callback.assert_called_once()
connector.stop()

# TODO: test other methods
def test_emit_mq_message(self):
from neon_mq_connector.utils.network_utils import b64_to_dict

test_config = {"server": "127.0.0.1",
"port": self.rmq_instance.port,
"users": {
"test": {
"user": "test_user",
"password": "test_password"
}}}
test_vhost = "/neon_testing"
test_queue = "test_queue"
connector = MQConnector(test_config, "test")
connector.vhost = test_vhost

request_data = {"test": True,
"data": ["test"]}

callback_event = threading.Event()
callback = Mock(side_effect=lambda *args: callback_event.set())
connector.register_consumer("test_consumer", vhost=test_vhost,
queue=test_queue, callback=callback)
connector.run(run_sync=False, run_observer=False)

open_event = threading.Event()
close_event = threading.Event()
on_open = Mock(side_effect=lambda *args: open_event.set())
on_error = Mock()
on_close = Mock(side_effect=lambda *args: close_event.set())

blocking_connection = BlockingConnection(
parameters=connector.get_connection_params(test_vhost))

async_connection = SelectConnection(
parameters=connector.get_connection_params(test_vhost),
on_open_callback=on_open, on_open_error_callback=on_error,
on_close_callback=on_close)
async_thread = threading.Thread(target=async_connection.ioloop.start,
daemon=True)
async_thread.start()

# Blocking connection emit
message_id = connector.emit_mq_message(blocking_connection,
request_data, queue=test_queue)
self.assertIsInstance(message_id, str)
callback_event.wait(timeout=5)
self.assertTrue(callback_event.is_set())
callback.assert_called_once()
self.assertEqual(b64_to_dict(callback.call_args.args[3]),
{**request_data, "message_id": message_id})
callback.reset_mock()
callback_event.clear()

# Async connection emit
open_event.wait(timeout=5)
self.assertTrue(open_event.is_set())
on_open.assert_called_once()
message_id_2 = connector.emit_mq_message(async_connection,
request_data, queue=test_queue)
self.assertIsInstance(message_id, str)
self.assertNotEqual(message_id, message_id_2)
callback_event.wait(timeout=5)
self.assertTrue(callback_event.is_set())
callback.assert_called_once()
self.assertEqual(b64_to_dict(callback.call_args.args[3]),
{**request_data, "message_id": message_id_2})

on_close.assert_not_called()
connector.stop()

async_connection.close()
close_event.wait(timeout=5)
self.assertTrue(close_event.is_set())
on_close.assert_called_once()

async_thread.join(3)
on_error.assert_not_called()

# TODO: test other methods

0 comments on commit cfc073f

Please sign in to comment.