From b74f2727acb4489d902f24a4ae36c3cb496d655c Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Tue, 25 Feb 2025 20:30:56 -0500 Subject: [PATCH] dont queue, but report immediately --- .../_single_threaded_agent_runtime.py | 46 +++++++++++++------ .../autogen-core/tests/test_runtime.py | 46 +++++++++++++++++++ 2 files changed, 77 insertions(+), 15 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py index fecd02a1d0e9..d299cefc8c69 100644 --- a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py @@ -159,6 +159,7 @@ class SingleThreadedAgentRuntime(AgentRuntime): intervention_handlers (List[InterventionHandler], optional): A list of intervention handlers that can intercept messages before they are sent or published. Defaults to None. tracer_provider (TracerProvider, optional): The tracer provider to use for tracing. Defaults to None. + ignore_unhandled_exceptions (bool, optional): Whether to ignore unhandled exceptions in that occur in agent event handlers. Any background exceptions will be raised on the next call to `process_next` or from an awaited `stop`, `stop_when_idle` or `stop_when`. Note, this does not apply to RPC handlers. Defaults to True. Examples: @@ -248,7 +249,7 @@ def __init__( *, intervention_handlers: List[InterventionHandler] | None = None, tracer_provider: TracerProvider | None = None, - ignore_unhandled_handler_exceptions: bool = True, + ignore_unhandled_exceptions: bool = True, ) -> None: self._tracer_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("SingleThreadedAgentRuntime")) self._message_queue: Queue[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = Queue() @@ -262,7 +263,8 @@ def __init__( self._subscription_manager = SubscriptionManager() self._run_context: RunContext | None = None self._serialization_registry = SerializationRegistry() - self._ignore_unhandled_handler_exceptions = ignore_unhandled_handler_exceptions + self._ignore_unhandled_handler_exceptions = ignore_unhandled_exceptions + self._background_exception: BaseException | None = None @property def unprocessed_messages_count( @@ -523,9 +525,9 @@ async def _on_message(agent: Agent, message_context: MessageContext) -> Any: responses.append(future) await asyncio.gather(*responses) - except BaseException: + except BaseException as e: if not self._ignore_unhandled_handler_exceptions: - raise + self._background_exception = e finally: self._message_queue.task_done() # TODO if responses are given for a publish @@ -554,15 +556,28 @@ async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> self._message_queue.task_done() async def process_next(self) -> None: - """Process the next message in the queue.""" + """Process the next message in the queue. + + If there is an unhandled exception in the background task, it will be raised here. `process_next` cannot be called again after an unhandled exception is raised. + """ await self._process_next() async def _process_next(self) -> None: """Process the next message in the queue.""" + if self._background_exception is not None: + e = self._background_exception + self._background_exception = None + self._message_queue.shutdown(immediate=True) # type: ignore + raise e + try: message_envelope = await self._message_queue.get() except QueueShutDown: + if self._background_exception is not None: + e = self._background_exception + self._background_exception = None + raise e return match message_envelope: @@ -646,10 +661,7 @@ def handle_process_exception(task: Task[Any]) -> None: Args: task: The task that has finished and has potentially raised an exception. """ - try: - task.result() - finally: - self._background_tasks.discard(task) + self._background_tasks.discard(task) task = asyncio.create_task(self._process_publish(message_envelope)) self._background_tasks.add(task) @@ -725,19 +737,23 @@ async def stop(self) -> None: if self._run_context is None: raise RuntimeError("Runtime is not started") - await self._run_context.stop() - self._run_context = None - self._message_queue = Queue() + try: + await self._run_context.stop() + finally: + self._run_context = None + self._message_queue = Queue() async def stop_when_idle(self) -> None: """Stop the runtime message processing loop when there is no outstanding message being processed or queued. This is the most common way to stop the runtime.""" if self._run_context is None: raise RuntimeError("Runtime is not started") - await self._run_context.stop_when_idle() - self._run_context = None - self._message_queue = Queue() + try: + await self._run_context.stop_when_idle() + finally: + self._run_context = None + self._message_queue = Queue() async def stop_when(self, condition: Callable[[], bool]) -> None: """Stop the runtime message processing loop when the condition is met. diff --git a/python/packages/autogen-core/tests/test_runtime.py b/python/packages/autogen-core/tests/test_runtime.py index 9a0e27507480..fd505ddcc4e7 100644 --- a/python/packages/autogen-core/tests/test_runtime.py +++ b/python/packages/autogen-core/tests/test_runtime.py @@ -6,9 +6,12 @@ AgentInstantiationContext, AgentType, DefaultTopicId, + MessageContext, + RoutedAgent, SingleThreadedAgentRuntime, TopicId, TypeSubscription, + event, try_get_known_serializers_for_type, type_subscription, ) @@ -23,6 +26,8 @@ from autogen_test_utils.telemetry_test_utils import MyTestExporter, get_test_tracer_provider from opentelemetry.sdk.trace import TracerProvider +from autogen_core._default_subscription import default_subscription + test_exporter = MyTestExporter() @@ -268,3 +273,44 @@ async def test_default_subscription_publish_to_other_source() -> None: assert other_long_running_agent.num_calls == 1 await runtime.close() + + +@default_subscription +class FailingAgent(RoutedAgent): + def __init__(self) -> None: + super().__init__("A failing agent.") + + @event + async def on_new_message_event( + self, message: MessageType, ctx: MessageContext + ) -> None: + raise ValueError("Test exception") + + +@pytest.mark.asyncio +async def test_event_handler_exception_propogates() -> None: + runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False) + await FailingAgent.register(runtime, "name", FailingAgent) + + + with pytest.raises(ValueError, match="Test exception"): + runtime.start() + await runtime.publish_message(MessageType(), topic_id=DefaultTopicId()) + await runtime.stop_when_idle() + + await runtime.close() + + +@pytest.mark.asyncio +async def test_event_handler_exception_multi_message() -> None: + runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False) + await FailingAgent.register(runtime, "name", FailingAgent) + + with pytest.raises(ValueError, match="Test exception"): + runtime.start() + await runtime.publish_message(MessageType(), topic_id=DefaultTopicId()) + await runtime.publish_message(MessageType(), topic_id=DefaultTopicId()) + await runtime.publish_message(MessageType(), topic_id=DefaultTopicId()) + await runtime.stop_when_idle() + + await runtime.close() \ No newline at end of file