[V1] [4/N] API Server: ZMQ/MP Utilities #11541

Merged
merged 18 commits into from
Dec 28, 2024
Changes from 1 commit
formatting
robertgshaw2-redhat committed Dec 27, 2024
commit a690dcb3297eb38e6b7e0e58a15a6fb0115cd957
1 change: 0 additions & 1 deletion vllm/v1/engine/async_llm.py
@@ -80,7 +80,6 @@ def __init__(
self.engine_core = EngineCoreClient.make_client(
vllm_config=vllm_config,
executor_class=executor_class,
usage_context=usage_context,
multiprocess_mode=True,
asyncio_mode=True,
)
84 changes: 4 additions & 80 deletions vllm/v1/engine/core.py
@@ -3,20 +3,17 @@
import signal
import threading
import time
from dataclasses import dataclass
from multiprocessing.process import BaseProcess
from multiprocessing.connection import Connection
from typing import List, Tuple, Type

import zmq
import zmq.asyncio
from msgspec import msgpack

from vllm.config import CacheConfig, VllmConfig
from vllm.executor.multiproc_worker_utils import get_mp_context
from vllm.logger import init_logger
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import zmq_socket_ctx
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
@@ -42,7 +39,6 @@ def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[Executor],
usage_context: UsageContext,
):
assert vllm_config.model_config.runner_type != "pooling"

@@ -134,29 +130,18 @@ def profile(self, is_start: bool = True):
self.model_executor.profile(is_start)


@dataclass
class EngineCoreProcHandle:
proc: BaseProcess
ready_path: str
input_path: str
output_path: str


class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""

READY_STR = "READY"

def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[Executor],
usage_context: UsageContext,
input_path: str,
output_path: str,
ready_path: str,
ready_pipe: Connection,
):
super().__init__(vllm_config, executor_class, usage_context)
super().__init__(vllm_config, executor_class)

# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
@@ -173,68 +158,7 @@ def __init__(
daemon=True).start()

# Send Readiness signal to EngineClient.
with zmq_socket_ctx(ready_path, zmq.constants.PUSH) as ready_socket:
ready_socket.send_string(EngineCoreProc.READY_STR)

@staticmethod
def wait_for_startup(
proc: BaseProcess,
ready_path: str,
) -> None:
"""Wait until the EngineCore is ready."""

try:
sync_ctx = zmq.Context() # type: ignore[attr-defined]
socket = sync_ctx.socket(zmq.constants.PULL)
socket.connect(ready_path)

# Wait for EngineCore to send EngineCoreProc.READY_STR.
while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
logger.debug("Waiting for EngineCoreProc to startup.")

if not proc.is_alive():
raise RuntimeError("EngineCoreProc failed to start.")

message = socket.recv_string()
assert message == EngineCoreProc.READY_STR

except BaseException as e:
logger.exception(e)
raise e

finally:
sync_ctx.destroy(linger=0)

@staticmethod
def make_engine_core_process(
vllm_config: VllmConfig,
executor_class: Type[Executor],
usage_context: UsageContext,
input_path: str,
output_path: str,
ready_path: str,
) -> EngineCoreProcHandle:
context = get_mp_context()

process_kwargs = {
"input_path": input_path,
"output_path": output_path,
"ready_path": ready_path,
"vllm_config": vllm_config,
"executor_class": executor_class,
"usage_context": usage_context,
}
# Run EngineCore busy loop in background process.
proc = context.Process(target=EngineCoreProc.run_engine_core,
kwargs=process_kwargs)
proc.start()

# Wait for startup
EngineCoreProc.wait_for_startup(proc, ready_path)
return EngineCoreProcHandle(proc=proc,
ready_path=ready_path,
input_path=input_path,
output_path=output_path)
ready_pipe.send({"status": "READY"})

@staticmethod
def run_engine_core(*args, **kwargs):
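The startup handshake changes here: instead of the removed `wait_for_startup` polling a ZMQ PULL socket for `READY_STR`, the child process now reports readiness over a `multiprocessing` Pipe passed in as `ready_pipe`. A minimal sketch of that handshake, under the assumption that the parent simply blocks on the reader end; the `start_child`/`child_main` names are illustrative, not from this PR:

```python
# Minimal sketch of a Pipe-based readiness handshake, reusing the
# {"status": "READY"} convention that EngineCoreProc sends above.
from multiprocessing import get_context
from multiprocessing.connection import Connection


def child_main(ready_pipe: Connection) -> None:
    # ... heavyweight initialization (model load, executor setup) ...
    ready_pipe.send({"status": "READY"})  # tell the parent we are up
    ready_pipe.close()
    # ... enter the busy loop ...


def start_child():
    ctx = get_context("spawn")
    reader, writer = ctx.Pipe(duplex=False)
    proc = ctx.Process(target=child_main, kwargs={"ready_pipe": writer})
    proc.start()

    # Block until the child reports readiness; a production version would
    # also watch proc.is_alive() to avoid hanging on a crashed child.
    if reader.recv()["status"] != "READY":
        raise RuntimeError("child process failed to start")
    return proc


if __name__ == "__main__":
    start_child().terminate()
```

Compared with the ZMQ version, this drops one IPC socket (`ready_path`) and the polling loop entirely.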
77 changes: 30 additions & 47 deletions vllm/v1/engine/core_client.py
@@ -1,19 +1,19 @@
import os
import weakref
from typing import List, Optional
from typing import List, Type

import msgspec
import zmq
import zmq.asyncio

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import get_open_zmq_ipc_path, kill_process_tree
from vllm.utils import get_open_zmq_ipc_path
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType, EngineCoreRequestUnion)
from vllm.v1.engine.core import (EngineCore, EngineCoreProc,
EngineCoreProcHandle)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import PickleEncoder
from vllm.v1.utils import BackgroundProcessHandler

logger = init_logger(__name__)

@@ -31,10 +31,10 @@ class EngineCoreClient:

@staticmethod
def make_client(
*args,
vllm_config: VllmConfig,
executor_class: Type[Executor],
multiprocess_mode: bool,
asyncio_mode: bool,
**kwargs,
) -> "EngineCoreClient":

# TODO: support this for debugging purposes.
Expand All @@ -44,12 +44,12 @@ def make_client(
"is not currently supported.")

if multiprocess_mode and asyncio_mode:
return AsyncMPClient(*args, **kwargs)
return AsyncMPClient(vllm_config, executor_class)

if multiprocess_mode and not asyncio_mode:
return SyncMPClient(*args, **kwargs)
return SyncMPClient(vllm_config, executor_class)

return InprocClient(*args, **kwargs)
return InprocClient(vllm_config, executor_class)

def shutdown(self):
pass
@@ -128,9 +128,9 @@ class MPClient(EngineCoreClient):

def __init__(
self,
*args,
vllm_config: VllmConfig,
executor_class: Type[Executor],
asyncio_mode: bool,
**kwargs,
):
# Serialization setup.
self.encoder = PickleEncoder()
@@ -143,7 +143,6 @@ def __init__(
self.ctx = zmq.Context() # type: ignore[attr-defined]

# Path for IPC.
ready_path = get_open_zmq_ipc_path()
output_path = get_open_zmq_ipc_path()
input_path = get_open_zmq_ipc_path()

@@ -156,47 +155,30 @@ def __init__(
self.input_socket.bind(input_path)

# Start EngineCore in background process.
self.proc_handle: Optional[EngineCoreProcHandle]
self.proc_handle = EngineCoreProc.make_engine_core_process(
*args,
input_path=
input_path, # type: ignore[misc] # MyPy incorrectly flags duplicate keywords
output_path=output_path, # type: ignore[misc]
ready_path=ready_path, # type: ignore[misc]
**kwargs,
)
self._finalizer = weakref.finalize(self, self.shutdown)
self.proc_handler = BackgroundProcessHandler(
input_path=input_path,
output_path=output_path,
process_name="EngineCore",
target_fn=EngineCoreProc.run_engine_core,
process_kwargs={
"vllm_config": vllm_config,
"executor_class": executor_class,
})

def shutdown(self):
# Shut down the zmq context.
self.ctx.destroy(linger=0)

if hasattr(self, "proc_handle") and self.proc_handle:
# Shutdown the process if needed.
if self.proc_handle.proc.is_alive():
self.proc_handle.proc.terminate()
self.proc_handle.proc.join(5)

if self.proc_handle.proc.is_alive():
kill_process_tree(self.proc_handle.proc.pid)

# Remove zmq ipc socket files
ipc_sockets = [
self.proc_handle.ready_path, self.proc_handle.output_path,
self.proc_handle.input_path
]
for ipc_socket in ipc_sockets:
socket_file = ipc_socket.replace("ipc://", "")
if os and os.path.exists(socket_file):
os.remove(socket_file)
self.proc_handle = None
if hasattr(self, "proc_handler") and self.proc_handler:
self.proc_handler.shutdown()


class SyncMPClient(MPClient):
"""Synchronous client for multi-proc EngineCore."""

def __init__(self, *args, **kwargs):
super().__init__(*args, asyncio_mode=False, **kwargs)
def __init__(self, vllm_config: VllmConfig,
executor_class: Type[Executor]):
super().__init__(vllm_config, executor_class, asyncio_mode=False)

def get_output(self) -> List[EngineCoreOutput]:

Expand Down Expand Up @@ -225,8 +207,9 @@ def profile(self, is_start: bool = True) -> None:
class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore."""

def __init__(self, *args, **kwargs):
super().__init__(*args, asyncio_mode=True, **kwargs)
def __init__(self, vllm_config: VllmConfig,
executor_class: Type[Executor]):
super().__init__(vllm_config, executor_class, asyncio_mode=True)

async def get_output_async(self) -> List[EngineCoreOutput]:

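With the process-management code gone from this file, `MPClient` keeps only what is client-specific: the ZMQ context, the two IPC paths, and a `weakref.finalize` hook so teardown also runs if `shutdown()` is never called explicitly. A rough sketch of that ownership pattern, with `DummyHandler` standing in for the real `BackgroundProcessHandler`:

```python
# Sketch of the finalizer-based cleanup pattern used by MPClient above.
# DummyHandler is a stand-in for vllm.v1.utils.BackgroundProcessHandler.
import weakref

import zmq


class DummyHandler:
    def shutdown(self) -> None:
        print("terminating background process and removing IPC sockets")


class Client:
    def __init__(self) -> None:
        self.ctx = zmq.Context()
        self.proc_handler = DummyHandler()
        # Registered so cleanup still happens at interpreter exit even if
        # nobody calls shutdown() explicitly.
        self._finalizer = weakref.finalize(self, self.shutdown)

    def shutdown(self) -> None:
        # Guard so a second call (e.g. from the finalizer at exit) is harmless.
        if not self.ctx.closed:
            self.ctx.destroy(linger=0)
        # hasattr guards against __init__ failing before the attribute was
        # set; clearing the handler keeps shutdown() idempotent.
        if hasattr(self, "proc_handler") and self.proc_handler:
            self.proc_handler.shutdown()
            self.proc_handler = None


client = Client()
client.shutdown()  # explicit shutdown, mirroring how MPClient is used
```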
5 changes: 2 additions & 3 deletions vllm/v1/engine/llm_engine.py
@@ -72,9 +72,8 @@ def __init__(

# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
self.engine_core = EngineCoreClient.make_client(
vllm_config,
executor_class,
usage_context,
vllm_config=vllm_config,
executor_class=executor_class,
multiprocess_mode=multiprocess_mode,
asyncio_mode=False,
)
39 changes: 21 additions & 18 deletions vllm/v1/utils.py
@@ -102,28 +102,22 @@ def shutdown(self):
os.remove(socket_file)


class MPBackgroundProcess:

def __init__(self):
self.proc_handle: Optional[BackgroundProcHandle]
self._finalizer = weakref.finalize(self, self.shutdown)

def __del__(self):
self.shutdown()

def shutdown(self):
if hasattr(self, "proc_handle") and self.proc_handle:
self.proc_handle.shutdown()
self.proc_handle = None

@staticmethod
def wait_for_startup(
class BackgroundProcessHandler:
"""
Utility class to handle creation, readiness, and shutdown
of background processes used by the AsyncLLM and LLMEngine.
"""

def __init__(
self,
input_path: str,
output_path: str,
process_name: str,
target_fn: Callable,
process_kwargs: Dict[Any, Any],
) -> BackgroundProcHandle:
):
self._finalizer = weakref.finalize(self, self.shutdown)

context = get_mp_context()
reader, writer = context.Pipe(duplex=False)

@@ -143,4 +137,13 @@ def wait_for_startup(
raise RuntimeError(f"{process_name} initialization failed. "
"See root cause above.")

return BackgroundProcHandle(proc, input_path, output_path)
self.proc_handle: Optional[BackgroundProcHandle]
self.proc_handle = BackgroundProcHandle(proc, input_path, output_path)

def __del__(self):
self.shutdown()

def shutdown(self):
if hasattr(self, "proc_handle") and self.proc_handle:
self.proc_handle.shutdown()
self.proc_handle = None
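Taken together with the `core_client.py` changes, a caller now gets process creation, the Pipe readiness check, and cleanup from a single object. A hedged usage sketch, assuming the handler forwards the IPC paths and the pipe's write end to `target_fn` as keyword arguments (which is how `EngineCoreProc.__init__` expects them); `worker_main` is hypothetical:

```python
from vllm.utils import get_open_zmq_ipc_path
from vllm.v1.utils import BackgroundProcessHandler


def worker_main(**kwargs) -> None:
    # Hypothetical target. Assumes the handler injects input_path,
    # output_path, and ready_pipe alongside process_kwargs, mirroring the
    # keyword arguments EngineCoreProc.__init__ declares.
    kwargs["ready_pipe"].send({"status": "READY"})
    # ... a real worker would now bind its ZMQ sockets and serve requests ...


if __name__ == "__main__":
    handler = BackgroundProcessHandler(
        input_path=get_open_zmq_ipc_path(),
        output_path=get_open_zmq_ipc_path(),
        process_name="Worker",
        target_fn=worker_main,
        process_kwargs={},
    )
    # ... exchange messages over the input/output IPC sockets ...
    handler.shutdown()  # terminates the process and removes the socket files
```

The AsyncLLM and LLMEngine paths above use exactly this shape, with `EngineCoreProc.run_engine_core` as the target and the vLLM config and executor class in `process_kwargs`.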