diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py index daa4ad65101d..621ed9511eb4 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py @@ -7,6 +7,7 @@ from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Sequence, Set, Tuple, TypeVar from autogen_core import TopicId +from autogen_core._agent_id import AgentId from autogen_core._runtime_impl_helpers import SubscriptionManager from ._constants import GRPC_IMPORT_ERROR_STR @@ -100,6 +101,9 @@ def __init__(self) -> None: self._data_connections: Dict[ ClientConnectionId, ChannelConnection[agent_worker_pb2.Message, agent_worker_pb2.Message] ] = {} + self._control_connections: Dict[ + ClientConnectionId, ChannelConnection[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage] + ] = {} self._agent_type_to_client_id_lock = asyncio.Lock() self._agent_type_to_client_id: Dict[str, ClientConnectionId] = {} self._pending_responses: Dict[ClientConnectionId, Dict[str, Future[Any]]] = {} @@ -140,7 +144,23 @@ async def OpenControlChannel( # type: ignore request_iterator: AsyncIterator[agent_worker_pb2.ControlMessage], context: grpc.aio.ServicerContext[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage], ) -> AsyncIterator[agent_worker_pb2.ControlMessage]: - raise NotImplementedError("Method not implemented.") + client_id = await get_client_id_or_abort(context) + + async def handle_callback(message: agent_worker_pb2.ControlMessage) -> None: + await self._receive_control_message(client_id, message) + + connection = CallbackChannelConnection[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage]( + request_iterator, client_id, handle_callback=handle_callback + ) + self._control_connections[client_id] = connection + logger.info(f"Client {client_id} connected.") + + try: + async for message in connection: + yield message + finally: + # Clean up the client connection. + del self._control_connections[client_id] async def _on_client_disconnect(self, client_id: ClientConnectionId) -> None: async with self._agent_type_to_client_id_lock: @@ -182,6 +202,29 @@ async def _receive_message(self, client_id: ClientConnectionId, message: agent_w case None: logger.warning("Received empty message") + async def _receive_control_message( + self, client_id: ClientConnectionId, message: agent_worker_pb2.ControlMessage + ) -> None: + logger.info(f"Received message from client {client_id}: {message}") + destination = message.destination + if destination.startswith("agentid="): + agent_id = AgentId.from_str(destination[len("agentid=") :]) + target_client_id = self._agent_type_to_client_id.get(agent_id.type) + if target_client_id is None: + logger.error(f"Agent client id not found for agent type {agent_id.type}.") + return + elif destination.startswith("clientid="): + target_client_id = destination[len("clientid=") :] + else: + logger.error(f"Invalid destination {destination}") + return + + target_send_queue = self._control_connections.get(target_client_id) + if target_send_queue is None: + logger.error(f"Client {target_client_id} not found, failed to deliver message.") + return + await target_send_queue.send(message) + async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: ClientConnectionId) -> None: # Deliver the message to a client given the target agent type. async with self._agent_type_to_client_id_lock: