Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Select Connections in emit_mq_message #115

Merged
merged 10 commits into from
Jan 23, 2025
71 changes: 43 additions & 28 deletions neon_mq_connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import copy
import time
import uuid
from asyncio import Event

import pika
import pika.exceptions

Expand All @@ -46,7 +48,6 @@
from neon_mq_connector.utils.network_utils import dict_to_b64
from neon_mq_connector.utils.thread_utils import RepeatingTimer


# DO NOT REMOVE ME: Defined for backward compatibility
ConsumerThread = BlockingConsumerThread

Expand Down Expand Up @@ -271,7 +272,8 @@ def create_unique_id():

@classmethod
def emit_mq_message(cls,
connection: pika.BlockingConnection,
connection: Union[pika.BlockingConnection,
pika.SelectConnection],
request_data: dict,
exchange: Optional[str] = '',
queue: Optional[str] = '',
Expand All @@ -292,6 +294,9 @@ def emit_mq_message(cls,
:raises ValueError: invalid request data provided
:returns message_id: id of the sent message
"""
# Make a copy of request_data to prevent modifying the input object
request_data = dict(request_data)

if not isinstance(request_data, dict):
raise TypeError(f"Expected dict and got {type(request_data)}")
if not request_data:
Expand All @@ -302,22 +307,32 @@ def emit_mq_message(cls,
.get("mq", {}).get("message_id") or
cls.create_unique_id())

with connection.channel() as channel:
def _on_channel_open(new_channel):
if exchange:
channel.exchange_declare(exchange=exchange,
exchange_type=exchange_type,
auto_delete=False)
new_channel.exchange_declare(exchange=exchange,
exchange_type=exchange_type,
auto_delete=False)
if queue:
declared_queue = channel.queue_declare(queue=queue,
auto_delete=False)
declared_queue = new_channel.queue_declare(queue=queue,
auto_delete=False)
if exchange_type == ExchangeType.fanout.value:
channel.queue_bind(queue=declared_queue.method.queue,
exchange=exchange)
channel.basic_publish(exchange=exchange or '',
routing_key=queue,
body=dict_to_b64(request_data),
properties=pika.BasicProperties(
expiration=str(expiration)))
new_channel.queue_bind(queue=declared_queue.method.queue,
exchange=exchange)
new_channel.basic_publish(exchange=exchange or '',
routing_key=queue,
body=dict_to_b64(request_data),
properties=pika.BasicProperties(
expiration=str(expiration)))

new_channel.close()

if isinstance(connection, pika.BlockingConnection):
LOG.debug(f"Using blocking connection for request: {request_data}")
_on_channel_open(connection.channel())
else:
LOG.debug(f"Using select connection for queue: {queue}")
connection.channel(on_open_callback=_on_channel_open)

LOG.debug(f"sent message: {request_data['message_id']}")
return request_data['message_id']

Expand Down Expand Up @@ -448,17 +463,17 @@ def register_consumer(self, name: str, vhost: str, queue: str,
self.consumer_properties.setdefault(name, {})
self.consumer_properties[name]['properties'] = \
dict(
name=name,
connection_params=self.get_connection_params(vhost),
queue=queue,
queue_reset=queue_reset,
callback_func=callback,
exchange=exchange,
exchange_reset=exchange_reset,
exchange_type=exchange_type,
error_func=error_handler,
auto_ack=auto_ack,
queue_exclusive=queue_exclusive,
name=name,
connection_params=self.get_connection_params(vhost),
queue=queue,
queue_reset=queue_reset,
callback_func=callback,
exchange=exchange,
exchange_reset=exchange_reset,
exchange_type=exchange_type,
error_func=error_handler,
auto_ack=auto_ack,
queue_exclusive=queue_exclusive,
)
self.consumer_properties[name]['restart_attempts'] = int(restart_attempts)
self.consumer_properties[name]['started'] = False
Expand Down Expand Up @@ -556,8 +571,8 @@ def run_consumers(self, names: Optional[tuple] = None, daemon=True):
names = list(self.consumers)
for name in names:
if (isinstance(self.consumers.get(name), SUPPORTED_THREADED_CONSUMERS)
and self.consumers[name].is_consumer_alive
and not self.consumers[name].is_consuming):
and self.consumers[name].is_consumer_alive
and not self.consumers[name].is_consuming):
self.consumers[name].daemon = daemon
self.consumers[name].start()
self.consumer_properties[name]['started'] = True
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
pika==1.2.0
pika~=1.2
ovos-config~=0.0,>=0.0.8
ovos-utils~=0.0,>=0.0.32
82 changes: 81 additions & 1 deletion tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@

from unittest.mock import Mock, patch
from ovos_utils.log import LOG
from pika.adapters.blocking_connection import BlockingConnection
from pika.adapters.select_connection import SelectConnection
from pika.exchange_type import ExchangeType

from neon_mq_connector.connector import MQConnector, ConsumerThreadInstance
Expand Down Expand Up @@ -384,4 +386,82 @@ def test_init_rmq_down(self, get_timeout):
callback.assert_called_once()
connector.stop()

# TODO: test other methods
def test_emit_mq_message(self):
from neon_mq_connector.utils.network_utils import b64_to_dict

test_config = {"server": "127.0.0.1",
"port": self.rmq_instance.port,
"users": {
"test": {
"user": "test_user",
"password": "test_password"
}}}
test_vhost = "/neon_testing"
test_queue = "test_queue"
connector = MQConnector(test_config, "test")
connector.vhost = test_vhost

request_data = {"test": True,
"data": ["test"]}

callback_event = threading.Event()
callback = Mock(side_effect=lambda *args: callback_event.set())
connector.register_consumer("test_consumer", vhost=test_vhost,
queue=test_queue, callback=callback)
connector.run(run_sync=False, run_observer=False)

open_event = threading.Event()
close_event = threading.Event()
on_open = Mock(side_effect=lambda *args: open_event.set())
on_error = Mock()
on_close = Mock(side_effect=lambda *args: close_event.set())

blocking_connection = BlockingConnection(
parameters=connector.get_connection_params(test_vhost))

async_connection = SelectConnection(
parameters=connector.get_connection_params(test_vhost),
on_open_callback=on_open, on_open_error_callback=on_error,
on_close_callback=on_close)
async_thread = threading.Thread(target=async_connection.ioloop.start,
daemon=True)
async_thread.start()

# Blocking connection emit
message_id = connector.emit_mq_message(blocking_connection,
request_data, queue=test_queue)
self.assertIsInstance(message_id, str)
callback_event.wait(timeout=5)
self.assertTrue(callback_event.is_set())
callback.assert_called_once()
self.assertEqual(b64_to_dict(callback.call_args.args[3]),
{**request_data, "message_id": message_id})
callback.reset_mock()
callback_event.clear()

# Async connection emit
open_event.wait(timeout=5)
self.assertTrue(open_event.is_set())
on_open.assert_called_once()
message_id_2 = connector.emit_mq_message(async_connection,
request_data, queue=test_queue)
self.assertIsInstance(message_id, str)
self.assertNotEqual(message_id, message_id_2)
callback_event.wait(timeout=5)
self.assertTrue(callback_event.is_set())
callback.assert_called_once()
self.assertEqual(b64_to_dict(callback.call_args.args[3]),
{**request_data, "message_id": message_id_2})

on_close.assert_not_called()
connector.stop()

async_connection.close()
close_event.wait(timeout=5)
self.assertTrue(close_event.is_set())
on_close.assert_called_once()

async_thread.join(3)
on_error.assert_not_called()

# TODO: test other methods
Loading