[Refactor] use zmq socket pool
Tong0217 committed Feb 27, 2025
1 parent 5743023 commit 9760db0
Showing 5 changed files with 151 additions and 108 deletions.
2 changes: 1 addition & 1 deletion llumnix/backends/vllm/llm_engine.py
@@ -83,7 +83,7 @@ def __init__(self,
self.put_queue_loop_thread = threading.Thread(
target=self._start_put_queue_loop, args=(), daemon=True, name="put_queue_loop"
)
self.async_put_queue_actor = ray.remote(
self.async_put_queue_actor: AsyncPutQueueActor = ray.remote(
num_cpus=1,
scheduling_strategy=scheduling_strategy
)(AsyncPutQueueActor).remote(instance_id, request_output_queue_type)
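For context: the annotated handle above is still a Ray ActorHandle at runtime — ray.remote(...)(AsyncPutQueueActor).remote(...) returns a handle, so the annotation mainly serves readers and IDEs. A minimal standalone sketch of the same actor-handle pattern, using a hypothetical Worker class rather than Llumnix code:

import ray

@ray.remote(num_cpus=1)
class Worker:  # hypothetical stand-in for AsyncPutQueueActor
    def __init__(self, name: str):
        self.name = name

    def ping(self) -> str:
        return f"pong from {self.name}"

ray.init(ignore_reinit_error=True)
worker = Worker.remote("w0")           # returns an ActorHandle, not a Worker instance
print(ray.get(worker.ping.remote()))   # "pong from w0"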
2 changes: 1 addition & 1 deletion llumnix/constants.py
@@ -45,7 +45,7 @@

# llumnix/queue/zmq_server.py
RPC_SOCKET_LIMIT_CUTOFF: int = 2000
RPC_ZMQ_HWM: int = 0
RPC_ZMQ_HWM: int = int(1e5)
RETRY_BIND_ADDRESS_INTERVAL: float = 10.0
MAX_BIND_ADDRESS_RETRY_TIMES: int = 10

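The high-water mark (HWM) is ZeroMQ's per-socket cap on queued, not-yet-delivered messages; 0 means "no limit", so raising the constant to int(1e5) bounds how much a socket may buffer before further sends block or drop, depending on socket type. A small pyzmq sketch of setting an HWM on a DEALER socket (illustrative only, not Llumnix code; the endpoint is hypothetical):

import zmq

context = zmq.Context()
socket = context.socket(zmq.DEALER)
socket.set_hwm(100000)                   # cap queued messages; 0 would mean unlimited
socket.connect("tcp://127.0.0.1:5555")   # hypothetical endpoint
print(socket.get_hwm())                  # 100000
socket.close(linger=0)
context.term()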
194 changes: 120 additions & 74 deletions llumnix/queue/zmq_client.py
@@ -11,8 +11,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any
from contextlib import contextmanager
from asyncio import Lock, Queue, QueueEmpty
import asyncio
from typing import Any, Dict
from collections.abc import Iterable
import time

@@ -21,105 +22,150 @@
import cloudpickle

from llumnix.logging.logger import init_logger
from llumnix.queue.queue_client_base import QueueClientBase
from llumnix.server_info import ServerInfo

from llumnix.queue.zmq_utils import (RPC_SUCCESS_STR, RPC_REQUEST_TYPE, RPCClientClosedError,
RPCUtilityRequest, RPCPutNoWaitQueueRequest, RPCPutNoWaitBatchQueueRequest,
get_open_zmq_ipc_path)
from llumnix.constants import RPC_GET_DATA_TIMEOUT_MS, RPC_SOCKET_LIMIT_CUTOFF, RPC_ZMQ_HWM
from llumnix.queue.zmq_utils import (
RPC_SUCCESS_STR,
RPC_REQUEST_TYPE,
RPCUtilityRequest,
RPCPutNoWaitQueueRequest,
RPCPutNoWaitBatchQueueRequest,
get_open_zmq_ipc_path,
get_zmq_connection_name,
)
from llumnix.constants import (
RPC_GET_DATA_TIMEOUT_MS,
)
from llumnix.metrics.timestamps import set_timestamp

logger = init_logger(__name__)


class ZmqClient:
def __init__(self):
self.context = zmq.asyncio.Context()
self._data_timeout = RPC_GET_DATA_TIMEOUT_MS
self._errored = False

# Maximum number of sockets that can be opened (typically 65536).
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT)
if socket_limit < RPC_SOCKET_LIMIT_CUTOFF:
raise ValueError(
f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps "
"the number of concurrent requests Llumnix can process.")

# We only have 1 ipc connection that uses unix sockets, so
# safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will
# not run into ulimit issues)
self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)

# This function is not called explicitly.
def close(self):
self.context.destroy()

@contextmanager
def to_socket(self, rpc_path):
# Raise a sensible error if the client was already closed.
# This can happen if a server shutdown is triggered but some coroutines
# are still running requests.
# There should not be a race condition with this check because we don't
# yield to the event loop between here and opening the socket.
if self.context.closed:
raise RPCClientClosedError("The ZMQ client has already shut down")

# Note that we use DEALER to enable asynchronous communication
# to enable streaming.
socket = self.context.socket(zmq.constants.DEALER)
socket.set_hwm(RPC_ZMQ_HWM)
class AsyncZmqSocketPool:
def __init__(self, context: zmq.asyncio.Context, ip, port, pool_size=10):
self.context = context
self.ip = ip
self.port = port
self.pool_size = pool_size
self.pool = Queue(maxsize=pool_size)
self.lock = asyncio.Lock()

async def _create_socket(self) -> zmq.asyncio.Socket:
socket = self.context.socket(zmq.DEALER)
dst_address = get_open_zmq_ipc_path(self.ip, self.port)
socket.connect(dst_address)
return socket

    async def get_socket(self):
        socket = None
        try:
            # Reuse an idle pooled socket when one is available.
            socket = self.pool.get_nowait()
        except QueueEmpty:
            # Pool is empty: open a new DEALER socket to this destination.
            async with self.lock:
                socket = await self._create_socket()
        return socket

def full(self):
return self.pool.full()

def put(self, socket: zmq.asyncio.Socket):
self.pool.put_nowait(socket)

async def close_all_connections(self):
while not self.pool.empty():
socket = await self.pool.get()
socket.close(linger=0)


class ZmqSocketPoolFactory:
def __init__(self, context, pool_size=10):
self.context: zmq.asyncio.Context = context
self.pool_size: int = pool_size
self.pools: Dict[str, AsyncZmqSocketPool] = {}
self.lock: asyncio.Lock = Lock()

async def get_pool(self, ip, port) -> AsyncZmqSocketPool:
async with self.lock:
dst_name = get_zmq_connection_name(ip, port)
if dst_name not in self.pools:
self.pools[dst_name] = AsyncZmqSocketPool(
ip=ip,
port=port,
context=self.context,
pool_size=self.pool_size,
)
return self.pools[dst_name]

async def close_all_pools(self):
for pool in self.pools.values():
await pool.close_all_connections()


class ZmqClient(QueueClientBase):
def __init__(self):
self.context = zmq.asyncio.Context(8)
self.socket_pool_factory: ZmqSocketPoolFactory = ZmqSocketPoolFactory(
context=self.context, pool_size=10
)
self.zmq_timeout_ms: int = RPC_GET_DATA_TIMEOUT_MS
self._conn_lock: asyncio.Lock = Lock()

async def _send_one_way_rpc_request(
self,
request: RPC_REQUEST_TYPE,
rpc_path: str,
error_message: str):
async def do_rpc_call(socket: zmq.asyncio.Socket,
request: RPC_REQUEST_TYPE):
self,
request: RPC_REQUEST_TYPE,
ip: str,
port: int,
error_message: str,
):
async def do_rpc_call(socket: zmq.asyncio.Socket, request: RPC_REQUEST_TYPE):

await socket.send_multipart([cloudpickle.dumps(request)])

if await socket.poll(timeout=self._data_timeout) == 0:
raise TimeoutError("Server didn't reply within "
f"{self._data_timeout} ms")
if await socket.poll(timeout=self.zmq_timeout_ms) == 0:
raise TimeoutError(
"Server didn't reply within " f"{self.zmq_timeout_ms} ms"
)

return cloudpickle.loads(await socket.recv())

with self.to_socket(rpc_path) as socket:
response = await do_rpc_call(socket, request)
socket_pool = await self.socket_pool_factory.get_pool(ip, port)
socket = await socket_pool.get_socket()
response = await do_rpc_call(socket, request)

if not isinstance(response, str) or response != RPC_SUCCESS_STR:
socket.close(linger=0)
if isinstance(response, Exception):
logger.error(error_message)
logger.error(f"{error_message}:{response}")
raise response
raise ValueError(error_message)

        if not socket_pool.full():
            socket_pool.put(socket)
        else:
            socket.close(linger=0)

    async def wait_for_server_rpc(self,
                                  server_info: ServerInfo):
        rpc_path = get_open_zmq_ipc_path(server_info.request_output_queue_ip, server_info.request_output_queue_port)

async def wait_for_server_rpc(self, server_info: ServerInfo):
await self._send_one_way_rpc_request(
request=RPCUtilityRequest.IS_SERVER_READY,
rpc_path=rpc_path,
error_message="Unable to start RPC Server")
request=RPCUtilityRequest.IS_SERVER_READY,
ip=server_info.request_output_queue_ip,
port=server_info.request_output_queue_port,
error_message="Unable to start RPC Server",
)

async def put_nowait(self, item: Any, server_info: ServerInfo):
rpc_path = get_open_zmq_ipc_path(server_info.request_output_queue_ip, server_info.request_output_queue_port)
set_timestamp(item, 'queue_client_send_timestamp', time.time())
set_timestamp(item, "queue_client_send_timestamp", time.time())
await self._send_one_way_rpc_request(
request=RPCPutNoWaitQueueRequest(item=item),
rpc_path=rpc_path,
error_message="Unable to put items into queue.")
request=RPCPutNoWaitQueueRequest(item=item),
ip=server_info.request_output_queue_ip,
port=server_info.request_output_queue_port,
error_message="Unable to put items into queue.",
)

async def put_nowait_batch(self, items: Iterable, server_info: ServerInfo):
rpc_path = get_open_zmq_ipc_path(server_info.request_output_queue_ip, server_info.request_output_queue_port)
set_timestamp(items, 'queue_client_send_timestamp', time.time())
await self._send_one_way_rpc_request(
request=RPCPutNoWaitBatchQueueRequest(items=items),
rpc_path=rpc_path,
error_message="Unable to put items into queue.")
request=RPCPutNoWaitBatchQueueRequest(items=items),
ip=server_info.request_output_queue_ip,
port=server_info.request_output_queue_port,
error_message="Unable to put items into queue.",
)
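To make the pooling flow above easier to follow, here is a minimal, self-contained sketch of the same acquire/reuse/return pattern with plain pyzmq and asyncio. Class and endpoint names are illustrative, not the Llumnix implementation:

import asyncio
import zmq
import zmq.asyncio


class SocketPool:
    """Reuse DEALER sockets per destination instead of opening one per request."""

    def __init__(self, context: zmq.asyncio.Context, address: str, pool_size: int = 10):
        self.context = context
        self.address = address
        self.pool: asyncio.Queue = asyncio.Queue(maxsize=pool_size)

    def acquire(self) -> zmq.asyncio.Socket:
        try:
            return self.pool.get_nowait()           # reuse an idle socket
        except asyncio.QueueEmpty:
            socket = self.context.socket(zmq.DEALER)
            socket.connect(self.address)            # open a new one on demand
            return socket

    def release(self, socket: zmq.asyncio.Socket) -> None:
        if self.pool.full():
            socket.close(linger=0)                  # pool already full: drop the extra socket
        else:
            self.pool.put_nowait(socket)


async def main() -> None:
    context = zmq.asyncio.Context()
    pool = SocketPool(context, "tcp://127.0.0.1:5555")   # hypothetical server address
    socket = pool.acquire()
    try:
        await socket.send_multipart([b"ping"])      # request frame; queued locally if no server
    finally:
        pool.release(socket)
    # Drain the pool and tear down the context so the process can exit cleanly.
    while not pool.pool.empty():
        pool.pool.get_nowait().close(linger=0)
    context.destroy(linger=0)


asyncio.run(main())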
58 changes: 26 additions & 32 deletions llumnix/queue/zmq_server.py
@@ -13,6 +13,7 @@

import asyncio
import time
import traceback
from typing import (Coroutine, Any)
from typing_extensions import Never

@@ -21,12 +22,12 @@
import zmq.error
import cloudpickle

from llumnix.queue.queue_server_base import QueueServerBase
from llumnix.queue.zmq_utils import (RPC_SUCCESS_STR, RPCPutNoWaitQueueRequest,
RPCPutNoWaitBatchQueueRequest, RPCUtilityRequest,
get_open_zmq_ipc_path)
from llumnix.logging.logger import init_logger
from llumnix.constants import (RPC_SOCKET_LIMIT_CUTOFF, RPC_ZMQ_HWM, RETRY_BIND_ADDRESS_INTERVAL,
MAX_BIND_ADDRESS_RETRY_TIMES)
from llumnix.constants import (RETRY_BIND_ADDRESS_INTERVAL, MAX_BIND_ADDRESS_RETRY_TIMES)
from llumnix.metrics.timestamps import set_timestamp

logger = init_logger(__name__)
Expand All @@ -39,41 +40,25 @@ class Full(Exception):
pass


class ZmqServer:
class ZmqServer(QueueServerBase):
def __init__(self, ip: str, port: int, maxsize=0):
rpc_path = get_open_zmq_ipc_path(ip, port)

self.context = zmq.asyncio.Context()

# Maximum number of sockets that can be opened (typically 65536).
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT)
if socket_limit < RPC_SOCKET_LIMIT_CUTOFF:
raise ValueError(
f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps "
"the number of concurrent requests Llumnix can process.")

# We only have 1 ipc connection that uses unix sockets, so
# safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will
# not run into ulimit issues)
self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)

self.socket = self.context.socket(zmq.constants.ROUTER)
self.socket.set_hwm(RPC_ZMQ_HWM)

self.context: zmq.asyncio.Context = zmq.asyncio.Context(8)
self.socket = self.context.socket(zmq.ROUTER)
self._stop_event = asyncio.Event()
endpoint = get_open_zmq_ipc_path(ip, port)
for attempt in range(MAX_BIND_ADDRESS_RETRY_TIMES):
try:
self.socket.bind(rpc_path)
logger.info("QueueServer's socket bind to: {}".format(rpc_path))
self.socket.bind(endpoint)
logger.info("QueueServer's socket bind to: {}".format(endpoint))
break
# pylint: disable=broad-except
except Exception as e:
logger.warning("QueueServer's socket bind to {} failed, exception: {}".format(rpc_path, e))
logger.warning("QueueServer's socket bind to {} failed, exception: {}".format(endpoint, e))
if attempt < MAX_BIND_ADDRESS_RETRY_TIMES - 1:
logger.warning("{} already in use, sleep {}s, and retry bind to it again.".format(rpc_path, RETRY_BIND_ADDRESS_INTERVAL))
logger.warning("{} already in use, sleep {}s, and retry bind to it again.".format(endpoint, RETRY_BIND_ADDRESS_INTERVAL))
time.sleep(RETRY_BIND_ADDRESS_INTERVAL)
else:
logger.error("{} still in use after {} times retries.".format(rpc_path, MAX_BIND_ADDRESS_RETRY_TIMES))
logger.error("{} still in use after {} times retries.".format(endpoint, MAX_BIND_ADDRESS_RETRY_TIMES))
raise

self.maxsize = maxsize
@@ -127,15 +112,15 @@ def get_nowait_batch(self, num_items):
)
return [self.queue.get_nowait() for _ in range(num_items)]

def _make_handler_coro(self, identity,
async def _make_handler_coro(self, identity,
message) -> Coroutine[Any, Any, Never]:
request = cloudpickle.loads(message)
if request == RPCUtilityRequest.IS_SERVER_READY:
return self._is_server_ready(identity)
return await self._is_server_ready(identity)
if isinstance(request, RPCPutNoWaitQueueRequest):
return self._put_nowait(identity, request)
return await self._put_nowait(identity, request)
if isinstance(request, RPCPutNoWaitBatchQueueRequest):
return self._put_nowait_batch(identity, request)
return await self._put_nowait_batch(identity, request)

raise ValueError(f"Unknown RPCRequest type: {request}")

@@ -166,6 +151,15 @@ async def _put_nowait_batch(self, identity, put_nowait_batch_queue_request: RPCP
await self.socket.send_multipart([identity, cloudpickle.dumps(e)])

async def run_server_loop(self):
# while not self._stop_event.is_set():
# try:
# identity, message = await self.socket.recv_multipart()
# except Exception:
        #         logger.error('Failed to receive message from zmq client.')
# logger.error(traceback.format_exc())
# continue
# await self._make_handler_coro(identity, message)

running_tasks = set()
while True:
identity, message = await self.socket.recv_multipart()
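The server side pairs a ROUTER socket with the clients' DEALER sockets: every frame a DEALER sends arrives at the ROUTER prefixed with an identity frame, and the reply must carry that identity back so it reaches the right client — which is what the identity handling in the _is_server_ready / _put_nowait replies above relies on. A bare-bones illustration of that routing contract, independent of Llumnix (the endpoint is hypothetical):

import asyncio
import zmq
import zmq.asyncio


async def serve(endpoint: str = "tcp://127.0.0.1:5556") -> None:
    context = zmq.asyncio.Context()
    socket = context.socket(zmq.ROUTER)
    socket.bind(endpoint)
    try:
        while True:
            # ROUTER prepends the sender's identity to every incoming message.
            identity, payload = await socket.recv_multipart()
            # Echo the identity back so ZeroMQ routes the reply to that client.
            await socket.send_multipart([identity, b"ack:" + payload])
    finally:
        socket.close(linger=0)
        context.destroy(linger=0)


if __name__ == "__main__":
    asyncio.run(serve())   # runs until interrupted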
3 changes: 3 additions & 0 deletions llumnix/queue/zmq_utils.py
@@ -44,3 +44,6 @@ class RPCClientClosedError(Exception):

def get_open_zmq_ipc_path(ip, port) -> str:
return "tcp://{}:{}".format(ip, port)

def get_zmq_connection_name(ip, port) -> str:
return "{}:{}".format(ip, port)
