Skip to content

Commit

Permalink
hive-messaging: Long-running publisher support
Browse files Browse the repository at this point in the history
  • Loading branch information
gbenson committed Oct 7, 2024
1 parent eee7250 commit d0375db
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 8 deletions.
1 change: 1 addition & 0 deletions libs/messaging/hive/messaging/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
DEFAULT_MESSAGE_BUS = MessageBus()

blocking_connection = DEFAULT_MESSAGE_BUS.blocking_connection
publisher_connection = DEFAULT_MESSAGE_BUS.publisher_connection
send_to_queue = DEFAULT_MESSAGE_BUS.send_to_queue
tell_user = DEFAULT_MESSAGE_BUS.tell_user

Expand Down
21 changes: 21 additions & 0 deletions libs/messaging/hive/messaging/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,3 +221,24 @@ def _wrapped_callback(channel, *args, **kwargs):
*args,
**kwargs
)


class PublisherChannel:
def __init__(self, invoker, channel):
self._invoker = invoker
self._channel = channel

def __getattr__(self, attr):
result = getattr(self._channel, attr)
if not callable(result):
return result
return PublisherInvoker(self._invoker, result)


class PublisherInvoker:
def __init__(self, invoker, func):
self._invoke = invoker
self._func = func

def __call__(self, *args, **kwargs):
return self._invoke(self._func, *args, **kwargs)
89 changes: 86 additions & 3 deletions libs/messaging/hive/messaging/connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
from .channel import Channel
import logging

from threading import Event, Thread

from hive.common.units import SECOND

from .channel import Channel, PublisherChannel
from .wrapper import WrappedPikaThing

logger = logging.getLogger(__name__)
d = logger.debug


class Connection(WrappedPikaThing):
def __init__(self, *args, **kwargs):
Expand All @@ -14,17 +23,91 @@ def __exit__(self, *exc_info):
if self._pika.is_open:
self._pika.close()

def channel(self, *args, **kwargs):
def _channel(self, *args, **kwargs) -> Channel:
return Channel(self._pika.channel(*args, **kwargs))

def channel(self, *args, **kwargs) -> Channel:
"""Like :class:pika.channel.Channel` but with different defaults.
:param confirm_delivery: Whether to enable delivery confirmations.
Hive's default is True. Use `confirm_delivery=False` for the
original Pika behaviour.
"""
confirm_delivery = kwargs.pop("confirm_delivery", True)
channel = Channel(self._pika.channel(*args, **kwargs))
channel = self._channel(*args, **kwargs)
if confirm_delivery:
channel.confirm_delivery() # Don't fail silently.
if self.on_channel_open:
self.on_channel_open(channel)
return channel


class PublisherConnection(Connection, Thread):
def __init__(self, *args, **kwargs):
thread_name = kwargs.pop("thread_name", "Publisher")
Thread.__init__(self, name=thread_name, daemon=True)
Connection.__init__(self, *args, **kwargs)
self.is_running = True

def __enter__(self):
logger.info("Starting publisher thread")
Thread.start(self)
return Connection.__enter__(self)

def run(self):
logger.info("%s: thread started", self.name)
while self.is_running:
self.process_data_events(time_limit=1 * SECOND)
logger.info("%s: thread stopping", self.name)
self.process_data_events(time_limit=1 * SECOND)
logger.info("%s: thread stopped", self.name)

def __exit__(self, *exc_info):
logger.info("Stopping publisher thread")
self.is_running = False
self.join()
logger.info("Publisher thread stopped")
return Connection.__exit__(self, *exc_info)

def _channel(self, *args, **kwargs) -> Channel:
return PublisherChannel(
self._invoke,
self._invoke(super()._channel, *args, **kwargs),
)

def _invoke(self, func, *args, **kwargs):
callback = PublisherCallback(func, args, kwargs)
self.add_callback_threadsafe(callback)
return callback.join()


class PublisherCallback:
def __init__(self, func, args, kwargs):
self._func = func
self._args = args
self._kwargs = kwargs
self._event = Event()
self._result = None
self._exception = None

def __call__(self):
d("Entering callback")
try:
self._result = self._func(*self._args, **self._kwargs)
except Exception as exc:
self._exception = exc
finally:
self._event.set()
del self._func, self._args, self._kwargs
d("Leaving callback")

def join(self, *args, **kwargs):
d("Waiting for callback")
self._event.wait(*args, **kwargs)
d("Callback returned")
try:
if self._exception:
raise self._exception
return self._result
finally:
del self._result, self._exception
21 changes: 16 additions & 5 deletions libs/messaging/hive/messaging/message_bus.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os

from dataclasses import dataclass, field
from typing import Optional
from typing import Callable, Optional

from pika import (
BlockingConnection,
Expand All @@ -14,7 +14,8 @@

from hive.config import read as read_config

from .connection import Connection
from .channel import Channel
from .connection import Connection, PublisherConnection


@dataclass
Expand Down Expand Up @@ -70,11 +71,16 @@ def queue_connection_parameters(
**kwargs
)

def blocking_connection(self, **kwargs) -> Connection:
on_channel_open = kwargs.pop("on_channel_open", None)
def blocking_connection(
self,
*,
connection_class: type[Connection] = Connection,
on_channel_open: Optional[Callable[[Channel], None]] = None,
**kwargs
) -> Connection:
params = self.queue_connection_parameters(**kwargs)
try:
return Connection(
return connection_class(
BlockingConnection(params),
on_channel_open=on_channel_open,
)
Expand All @@ -85,6 +91,11 @@ def blocking_connection(self, **kwargs) -> Connection:
raise e
raise

def publisher_connection(self, **kwargs) -> Connection:
return self.blocking_connection(
connection_class=PublisherConnection,
**kwargs)

def send_to_queue(self, queue: str, *args, **kwargs):
durable = kwargs.pop("durable", True)
with self.blocking_connection(connection_attempts=1) as conn:
Expand Down

0 comments on commit d0375db

Please sign in to comment.