From 9760db09aa556c3a30817e899a1c76374d7a403c Mon Sep 17 00:00:00 2001 From: tongchenghao Date: Thu, 27 Feb 2025 14:19:39 +0800 Subject: [PATCH] [Refactor] use zmq socket pool --- llumnix/backends/vllm/llm_engine.py | 2 +- llumnix/constants.py | 2 +- llumnix/queue/zmq_client.py | 194 +++++++++++++++++----------- llumnix/queue/zmq_server.py | 58 ++++----- llumnix/queue/zmq_utils.py | 3 + 5 files changed, 151 insertions(+), 108 deletions(-) diff --git a/llumnix/backends/vllm/llm_engine.py b/llumnix/backends/vllm/llm_engine.py index ad01d54c..8c4bcc10 100644 --- a/llumnix/backends/vllm/llm_engine.py +++ b/llumnix/backends/vllm/llm_engine.py @@ -83,7 +83,7 @@ def __init__(self, self.put_queue_loop_thread = threading.Thread( target=self._start_put_queue_loop, args=(), daemon=True, name="put_queue_loop" ) - self.async_put_queue_actor = ray.remote( + self.async_put_queue_actor: AsyncPutQueueActor = ray.remote( num_cpus=1, scheduling_strategy=scheduling_strategy )(AsyncPutQueueActor).remote(instance_id, request_output_queue_type) diff --git a/llumnix/constants.py b/llumnix/constants.py index b6294ad3..47b3d106 100644 --- a/llumnix/constants.py +++ b/llumnix/constants.py @@ -45,7 +45,7 @@ # llumnix/queue/zmq_server.py RPC_SOCKET_LIMIT_CUTOFF: int = 2000 -RPC_ZMQ_HWM: int = 0 +RPC_ZMQ_HWM: int = int(1e5) RETRY_BIND_ADDRESS_INTERVAL: float = 10.0 MAX_BIND_ADDRESS_RETRY_TIMES: int = 10 diff --git a/llumnix/queue/zmq_client.py b/llumnix/queue/zmq_client.py index cdec50ab..6746b16f 100644 --- a/llumnix/queue/zmq_client.py +++ b/llumnix/queue/zmq_client.py @@ -11,8 +11,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any -from contextlib import contextmanager +from asyncio import Lock, Queue, QueueEmpty +import asyncio +from typing import Any, Dict from collections.abc import Iterable import time @@ -21,105 +22,150 @@ import cloudpickle from llumnix.logging.logger import init_logger +from llumnix.queue.queue_client_base import QueueClientBase from llumnix.server_info import ServerInfo -from llumnix.queue.zmq_utils import (RPC_SUCCESS_STR, RPC_REQUEST_TYPE, RPCClientClosedError, - RPCUtilityRequest, RPCPutNoWaitQueueRequest, RPCPutNoWaitBatchQueueRequest, - get_open_zmq_ipc_path) -from llumnix.constants import RPC_GET_DATA_TIMEOUT_MS, RPC_SOCKET_LIMIT_CUTOFF, RPC_ZMQ_HWM +from llumnix.queue.zmq_utils import ( + RPC_SUCCESS_STR, + RPC_REQUEST_TYPE, + RPCUtilityRequest, + RPCPutNoWaitQueueRequest, + RPCPutNoWaitBatchQueueRequest, + get_open_zmq_ipc_path, + get_zmq_connection_name, +) +from llumnix.constants import ( + RPC_GET_DATA_TIMEOUT_MS, +) from llumnix.metrics.timestamps import set_timestamp logger = init_logger(__name__) -class ZmqClient: - def __init__(self): - self.context = zmq.asyncio.Context() - self._data_timeout = RPC_GET_DATA_TIMEOUT_MS - self._errored = False - - # Maximum number of sockets that can be opened (typically 65536). - # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get) - socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT) - if socket_limit < RPC_SOCKET_LIMIT_CUTOFF: - raise ValueError( - f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps " - "the number of concurrent requests Llumnix can process.") - - # We only have 1 ipc connection that uses unix sockets, so - # safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will - # not run into ulimit issues) - self.context.set(zmq.constants.MAX_SOCKETS, socket_limit) - - # This function is not called explicitly. - def close(self): - self.context.destroy() - - @contextmanager - def to_socket(self, rpc_path): - # Raise a sensible error if the client was already closed. - # This can happen if a server shutdown is triggered but some coroutines - # are still running requests. - # There should not be a race condition with this check because we don't - # yield to the event loop between here and opening the socket. - if self.context.closed: - raise RPCClientClosedError("The ZMQ client has already shut down") - - # Note that we use DEALER to enable asynchronous communication - # to enable streaming. - socket = self.context.socket(zmq.constants.DEALER) - socket.set_hwm(RPC_ZMQ_HWM) +class AsyncZmqSocketPool: + def __init__(self, context: zmq.asyncio.Context, ip, port, pool_size=10): + self.context = context + self.ip = ip + self.port = port + self.pool_size = pool_size + self.pool = Queue(maxsize=pool_size) + self.lock = asyncio.Lock() + + async def _create_socket(self) -> zmq.asyncio.Socket: + socket = self.context.socket(zmq.DEALER) + dst_address = get_open_zmq_ipc_path(self.ip, self.port) + socket.connect(dst_address) + return socket + + async def get_socket(self): + socket = None try: - socket.connect(rpc_path) - yield socket - finally: + socket = self.pool.get_nowait() + except QueueEmpty: + async with self.lock: + socket = await self._create_socket() + return socket + + def full(self): + return self.pool.full() + + def put(self, socket: zmq.asyncio.Socket): + self.pool.put_nowait(socket) + + async def close_all_connections(self): + while not self.pool.empty(): + socket = await self.pool.get() socket.close(linger=0) + +class ZmqSocketPoolFactory: + def __init__(self, context, pool_size=10): + self.context: zmq.asyncio.Context = context + self.pool_size: int = pool_size + self.pools: Dict[str, AsyncZmqSocketPool] = {} + self.lock: asyncio.Lock = Lock() + + async def get_pool(self, ip, port) -> AsyncZmqSocketPool: + async with self.lock: + dst_name = get_zmq_connection_name(ip, port) + if dst_name not in self.pools: + self.pools[dst_name] = AsyncZmqSocketPool( + ip=ip, + port=port, + context=self.context, + pool_size=self.pool_size, + ) + return self.pools[dst_name] + + async def close_all_pools(self): + for pool in self.pools.values(): + await pool.close_all_connections() + + +class ZmqClient(QueueClientBase): + def __init__(self): + self.context = zmq.asyncio.Context(8) + self.socket_pool_factory: ZmqSocketPoolFactory = ZmqSocketPoolFactory( + context=self.context, pool_size=10 + ) + self.zmq_timeout_ms: int = RPC_GET_DATA_TIMEOUT_MS + self._conn_lock: asyncio.Lock = Lock() + async def _send_one_way_rpc_request( - self, - request: RPC_REQUEST_TYPE, - rpc_path: str, - error_message: str): - async def do_rpc_call(socket: zmq.asyncio.Socket, - request: RPC_REQUEST_TYPE): + self, + request: RPC_REQUEST_TYPE, + ip: str, + port: int, + error_message: str, + ): + async def do_rpc_call(socket: zmq.asyncio.Socket, request: RPC_REQUEST_TYPE): await socket.send_multipart([cloudpickle.dumps(request)]) - if await socket.poll(timeout=self._data_timeout) == 0: - raise TimeoutError("Server didn't reply within " - f"{self._data_timeout} ms") + if await socket.poll(timeout=self.zmq_timeout_ms) == 0: + raise TimeoutError( + "Server didn't reply within " f"{self.zmq_timeout_ms} ms" + ) return cloudpickle.loads(await socket.recv()) - with self.to_socket(rpc_path) as socket: - response = await do_rpc_call(socket, request) + socket_pool = await self.socket_pool_factory.get_pool(ip, port) + socket = await socket_pool.get_socket() + response = await do_rpc_call(socket, request) if not isinstance(response, str) or response != RPC_SUCCESS_STR: + socket.close(linger=0) if isinstance(response, Exception): - logger.error(error_message) + logger.error(f"{error_message}:{response}") raise response raise ValueError(error_message) - async def wait_for_server_rpc(self, - server_info: ServerInfo): - rpc_path = get_open_zmq_ipc_path(server_info.request_output_queue_ip, server_info.request_output_queue_port) + if not socket_pool.full(): + socket_pool.put(socket) + else: + socket.close(linger=0) + + async def wait_for_server_rpc(self, server_info: ServerInfo): await self._send_one_way_rpc_request( - request=RPCUtilityRequest.IS_SERVER_READY, - rpc_path=rpc_path, - error_message="Unable to start RPC Server") + request=RPCUtilityRequest.IS_SERVER_READY, + ip=server_info.request_output_queue_ip, + port=server_info.request_output_queue_port, + error_message="Unable to start RPC Server", + ) async def put_nowait(self, item: Any, server_info: ServerInfo): - rpc_path = get_open_zmq_ipc_path(server_info.request_output_queue_ip, server_info.request_output_queue_port) - set_timestamp(item, 'queue_client_send_timestamp', time.time()) + set_timestamp(item, "queue_client_send_timestamp", time.time()) await self._send_one_way_rpc_request( - request=RPCPutNoWaitQueueRequest(item=item), - rpc_path=rpc_path, - error_message="Unable to put items into queue.") + request=RPCPutNoWaitQueueRequest(item=item), + ip=server_info.request_output_queue_ip, + port=server_info.request_output_queue_port, + error_message="Unable to put items into queue.", + ) async def put_nowait_batch(self, items: Iterable, server_info: ServerInfo): - rpc_path = get_open_zmq_ipc_path(server_info.request_output_queue_ip, server_info.request_output_queue_port) - set_timestamp(items, 'queue_client_send_timestamp', time.time()) await self._send_one_way_rpc_request( - request=RPCPutNoWaitBatchQueueRequest(items=items), - rpc_path=rpc_path, - error_message="Unable to put items into queue.") + request=RPCPutNoWaitBatchQueueRequest(items=items), + ip=server_info.request_output_queue_ip, + port=server_info.request_output_queue_port, + error_message="Unable to put items into queue.", + ) diff --git a/llumnix/queue/zmq_server.py b/llumnix/queue/zmq_server.py index e4c8c24c..cd5dcfc8 100644 --- a/llumnix/queue/zmq_server.py +++ b/llumnix/queue/zmq_server.py @@ -13,6 +13,7 @@ import asyncio import time +import traceback from typing import (Coroutine, Any) from typing_extensions import Never @@ -21,12 +22,12 @@ import zmq.error import cloudpickle +from llumnix.queue.queue_server_base import QueueServerBase from llumnix.queue.zmq_utils import (RPC_SUCCESS_STR, RPCPutNoWaitQueueRequest, RPCPutNoWaitBatchQueueRequest, RPCUtilityRequest, get_open_zmq_ipc_path) from llumnix.logging.logger import init_logger -from llumnix.constants import (RPC_SOCKET_LIMIT_CUTOFF, RPC_ZMQ_HWM, RETRY_BIND_ADDRESS_INTERVAL, - MAX_BIND_ADDRESS_RETRY_TIMES) +from llumnix.constants import (RETRY_BIND_ADDRESS_INTERVAL, MAX_BIND_ADDRESS_RETRY_TIMES) from llumnix.metrics.timestamps import set_timestamp logger = init_logger(__name__) @@ -39,41 +40,25 @@ class Full(Exception): pass -class ZmqServer: +class ZmqServer(QueueServerBase): def __init__(self, ip: str, port: int, maxsize=0): - rpc_path = get_open_zmq_ipc_path(ip, port) - - self.context = zmq.asyncio.Context() - - # Maximum number of sockets that can be opened (typically 65536). - # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get) - socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT) - if socket_limit < RPC_SOCKET_LIMIT_CUTOFF: - raise ValueError( - f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps " - "the number of concurrent requests Llumnix can process.") - - # We only have 1 ipc connection that uses unix sockets, so - # safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will - # not run into ulimit issues) - self.context.set(zmq.constants.MAX_SOCKETS, socket_limit) - - self.socket = self.context.socket(zmq.constants.ROUTER) - self.socket.set_hwm(RPC_ZMQ_HWM) - + self.context: zmq.asyncio.Context = zmq.asyncio.Context(8) + self.socket = self.context.socket(zmq.ROUTER) + self._stop_event = asyncio.Event() + endpoint = get_open_zmq_ipc_path(ip, port) for attempt in range(MAX_BIND_ADDRESS_RETRY_TIMES): try: - self.socket.bind(rpc_path) - logger.info("QueueServer's socket bind to: {}".format(rpc_path)) + self.socket.bind(endpoint) + logger.info("QueueServer's socket bind to: {}".format(endpoint)) break # pylint: disable=broad-except except Exception as e: - logger.warning("QueueServer's socket bind to {} failed, exception: {}".format(rpc_path, e)) + logger.warning("QueueServer's socket bind to {} failed, exception: {}".format(endpoint, e)) if attempt < MAX_BIND_ADDRESS_RETRY_TIMES - 1: - logger.warning("{} already in use, sleep {}s, and retry bind to it again.".format(rpc_path, RETRY_BIND_ADDRESS_INTERVAL)) + logger.warning("{} already in use, sleep {}s, and retry bind to it again.".format(endpoint, RETRY_BIND_ADDRESS_INTERVAL)) time.sleep(RETRY_BIND_ADDRESS_INTERVAL) else: - logger.error("{} still in use after {} times retries.".format(rpc_path, MAX_BIND_ADDRESS_RETRY_TIMES)) + logger.error("{} still in use after {} times retries.".format(endpoint, MAX_BIND_ADDRESS_RETRY_TIMES)) raise self.maxsize = maxsize @@ -127,15 +112,15 @@ def get_nowait_batch(self, num_items): ) return [self.queue.get_nowait() for _ in range(num_items)] - def _make_handler_coro(self, identity, + async def _make_handler_coro(self, identity, message) -> Coroutine[Any, Any, Never]: request = cloudpickle.loads(message) if request == RPCUtilityRequest.IS_SERVER_READY: - return self._is_server_ready(identity) + return await self._is_server_ready(identity) if isinstance(request, RPCPutNoWaitQueueRequest): - return self._put_nowait(identity, request) + return await self._put_nowait(identity, request) if isinstance(request, RPCPutNoWaitBatchQueueRequest): - return self._put_nowait_batch(identity, request) + return await self._put_nowait_batch(identity, request) raise ValueError(f"Unknown RPCRequest type: {request}") @@ -166,6 +151,15 @@ async def _put_nowait_batch(self, identity, put_nowait_batch_queue_request: RPCP await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) async def run_server_loop(self): + # while not self._stop_event.is_set(): + # try: + # identity, message = await self.socket.recv_multipart() + # except Exception: + # logger.error('Failed to receive message from zmq clent.') + # logger.error(traceback.format_exc()) + # continue + # await self._make_handler_coro(identity, message) + running_tasks = set() while True: identity, message = await self.socket.recv_multipart() diff --git a/llumnix/queue/zmq_utils.py b/llumnix/queue/zmq_utils.py index ade863ba..542410c8 100644 --- a/llumnix/queue/zmq_utils.py +++ b/llumnix/queue/zmq_utils.py @@ -44,3 +44,6 @@ class RPCClientClosedError(Exception): def get_open_zmq_ipc_path(ip, port) -> str: return "tcp://{}:{}".format(ip, port) + +def get_zmq_connection_name(ip, port) -> str: + return "{}:{}".format(ip, port)