Skip to content

Commit

Permalink
dont queue, but report immediately
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits committed Feb 26, 2025
1 parent f914f9d commit b74f272
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
intervention_handlers (List[InterventionHandler], optional): A list of intervention
handlers that can intercept messages before they are sent or published. Defaults to None.
tracer_provider (TracerProvider, optional): The tracer provider to use for tracing. Defaults to None.
ignore_unhandled_exceptions (bool, optional): Whether to ignore unhandled exceptions in that occur in agent event handlers. Any background exceptions will be raised on the next call to `process_next` or from an awaited `stop`, `stop_when_idle` or `stop_when`. Note, this does not apply to RPC handlers. Defaults to True.
Examples:
Expand Down Expand Up @@ -248,7 +249,7 @@ def __init__(
*,
intervention_handlers: List[InterventionHandler] | None = None,
tracer_provider: TracerProvider | None = None,
ignore_unhandled_handler_exceptions: bool = True,
ignore_unhandled_exceptions: bool = True,
) -> None:
self._tracer_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("SingleThreadedAgentRuntime"))
self._message_queue: Queue[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = Queue()
Expand All @@ -262,7 +263,8 @@ def __init__(
self._subscription_manager = SubscriptionManager()
self._run_context: RunContext | None = None
self._serialization_registry = SerializationRegistry()
self._ignore_unhandled_handler_exceptions = ignore_unhandled_handler_exceptions
self._ignore_unhandled_handler_exceptions = ignore_unhandled_exceptions
self._background_exception: BaseException | None = None

@property
def unprocessed_messages_count(
Expand Down Expand Up @@ -523,9 +525,9 @@ async def _on_message(agent: Agent, message_context: MessageContext) -> Any:
responses.append(future)

await asyncio.gather(*responses)
except BaseException:
except BaseException as e:
if not self._ignore_unhandled_handler_exceptions:
raise
self._background_exception = e
finally:
self._message_queue.task_done()
# TODO if responses are given for a publish
Expand Down Expand Up @@ -554,15 +556,28 @@ async def _process_response(self, message_envelope: ResponseMessageEnvelope) ->
self._message_queue.task_done()

async def process_next(self) -> None:
"""Process the next message in the queue."""
"""Process the next message in the queue.
If there is an unhandled exception in the background task, it will be raised here. `process_next` cannot be called again after an unhandled exception is raised.
"""
await self._process_next()

async def _process_next(self) -> None:
"""Process the next message in the queue."""

if self._background_exception is not None:
e = self._background_exception
self._background_exception = None
self._message_queue.shutdown(immediate=True) # type: ignore
raise e

try:
message_envelope = await self._message_queue.get()
except QueueShutDown:
if self._background_exception is not None:
e = self._background_exception
self._background_exception = None
raise e
return

match message_envelope:
Expand Down Expand Up @@ -646,10 +661,7 @@ def handle_process_exception(task: Task[Any]) -> None:
Args:
task: The task that has finished and has potentially raised an exception.
"""
try:
task.result()
finally:
self._background_tasks.discard(task)
self._background_tasks.discard(task)

task = asyncio.create_task(self._process_publish(message_envelope))
self._background_tasks.add(task)
Expand Down Expand Up @@ -725,19 +737,23 @@ async def stop(self) -> None:
if self._run_context is None:
raise RuntimeError("Runtime is not started")

await self._run_context.stop()
self._run_context = None
self._message_queue = Queue()
try:
await self._run_context.stop()
finally:
self._run_context = None
self._message_queue = Queue()

async def stop_when_idle(self) -> None:
"""Stop the runtime message processing loop when there is
no outstanding message being processed or queued. This is the most common way to stop the runtime."""
if self._run_context is None:
raise RuntimeError("Runtime is not started")
await self._run_context.stop_when_idle()

self._run_context = None
self._message_queue = Queue()
try:
await self._run_context.stop_when_idle()
finally:
self._run_context = None
self._message_queue = Queue()

async def stop_when(self, condition: Callable[[], bool]) -> None:
"""Stop the runtime message processing loop when the condition is met.
Expand Down
46 changes: 46 additions & 0 deletions python/packages/autogen-core/tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
AgentInstantiationContext,
AgentType,
DefaultTopicId,
MessageContext,
RoutedAgent,
SingleThreadedAgentRuntime,
TopicId,
TypeSubscription,
event,
try_get_known_serializers_for_type,
type_subscription,
)
Expand All @@ -23,6 +26,8 @@
from autogen_test_utils.telemetry_test_utils import MyTestExporter, get_test_tracer_provider
from opentelemetry.sdk.trace import TracerProvider

from autogen_core._default_subscription import default_subscription

test_exporter = MyTestExporter()


Expand Down Expand Up @@ -268,3 +273,44 @@ async def test_default_subscription_publish_to_other_source() -> None:
assert other_long_running_agent.num_calls == 1

await runtime.close()


@default_subscription
class FailingAgent(RoutedAgent):
def __init__(self) -> None:
super().__init__("A failing agent.")

@event
async def on_new_message_event(
self, message: MessageType, ctx: MessageContext
) -> None:
raise ValueError("Test exception")


@pytest.mark.asyncio
async def test_event_handler_exception_propogates() -> None:
runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False)
await FailingAgent.register(runtime, "name", FailingAgent)


with pytest.raises(ValueError, match="Test exception"):
runtime.start()
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
await runtime.stop_when_idle()

await runtime.close()


@pytest.mark.asyncio
async def test_event_handler_exception_multi_message() -> None:
runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False)
await FailingAgent.register(runtime, "name", FailingAgent)

with pytest.raises(ValueError, match="Test exception"):
runtime.start()
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
await runtime.stop_when_idle()

await runtime.close()

0 comments on commit b74f272

Please sign in to comment.