Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[V1] [4/N] API Server: ZMQ/MP Utilities #11541

Merged
merged 18 commits into from
Dec 28, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
working
  • Loading branch information
robertgshaw2-redhat committed Dec 27, 2024
commit 340829b38a7a705585b53b32112f1d161db64426
5 changes: 3 additions & 2 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,11 @@ def __init__(

# EngineCore (starts the engine in background process).
self.engine_core = EngineCoreClient.make_client(
vllm_config=vllm_config,
executor_class=executor_class,
multiprocess_mode=True,
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=self.log_stats,
)

self.output_handler: Optional[asyncio.Task] = None
Expand Down
14 changes: 10 additions & 4 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

POLLING_TIMEOUT_MS = 5000
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
LOGGING_TIME_S = 1
LOGGING_TIME_S = 5


class EngineCore:
Expand All @@ -39,8 +39,10 @@ def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[Executor],
log_stats: bool = False,
):
assert vllm_config.model_config.runner_type != "pooling"
self.log_stats = log_stats

logger.info("Initializing an LLM engine (v%s) with config: %s",
VLLM_VERSION, vllm_config)
Expand Down Expand Up @@ -135,13 +137,14 @@ class EngineCoreProc(EngineCore):

def __init__(
self,
vllm_config: VllmConfig,
executor_class: Type[Executor],
input_path: str,
output_path: str,
ready_pipe: Connection,
vllm_config: VllmConfig,
executor_class: Type[Executor],
log_stats: bool = False,
):
super().__init__(vllm_config, executor_class)
super().__init__(vllm_config, executor_class, log_stats)

# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
Expand Down Expand Up @@ -233,6 +236,9 @@ def run_busy_loop(self):
def _log_stats(self):
"""Log basic stats every LOGGING_TIME_S"""

if not self.log_stats:
return

now = time.time()

if now - self._last_logging_time > LOGGING_TIME_S:
Expand Down
41 changes: 29 additions & 12 deletions vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ class EngineCoreClient:

@staticmethod
def make_client(
vllm_config: VllmConfig,
executor_class: Type[Executor],
multiprocess_mode: bool,
asyncio_mode: bool,
vllm_config: VllmConfig,
executor_class: Type[Executor],
log_stats: bool = False,
) -> "EngineCoreClient":

# TODO: support this for debugging purposes.
Expand All @@ -44,12 +45,12 @@ def make_client(
"is not currently supported.")

if multiprocess_mode and asyncio_mode:
return AsyncMPClient(vllm_config, executor_class)
return AsyncMPClient(vllm_config, executor_class, log_stats)

if multiprocess_mode and not asyncio_mode:
return SyncMPClient(vllm_config, executor_class)
return SyncMPClient(vllm_config, executor_class, log_stats)

return InprocClient(vllm_config, executor_class)
return InprocClient(vllm_config, executor_class, log_stats)

def shutdown(self):
pass
Expand Down Expand Up @@ -128,9 +129,10 @@ class MPClient(EngineCoreClient):

def __init__(
self,
asyncio_mode: bool,
vllm_config: VllmConfig,
executor_class: Type[Executor],
asyncio_mode: bool,
log_stats: bool = False,
):
# Serialization setup.
self.encoder = PickleEncoder()
Expand Down Expand Up @@ -164,6 +166,7 @@ def __init__(
process_kwargs={
"vllm_config": vllm_config,
"executor_class": executor_class,
"log_stats": log_stats,
})

def shutdown(self):
Expand All @@ -178,9 +181,16 @@ def shutdown(self):
class SyncMPClient(MPClient):
"""Synchronous client for multi-proc EngineCore."""

def __init__(self, vllm_config: VllmConfig,
executor_class: Type[Executor]):
super().__init__(vllm_config, executor_class, asyncio_mode=False)
def __init__(
    self,
    vllm_config: VllmConfig,
    executor_class: Type[Executor],
    log_stats: bool = False,
):
    """Construct the synchronous multi-process client.

    Forwards all arguments to MPClient.__init__ with
    asyncio_mode=False, so the underlying ZMQ plumbing is set up
    for blocking (non-asyncio) use.
    """
    super().__init__(
        asyncio_mode=False,
        vllm_config=vllm_config,
        executor_class=executor_class,
        log_stats=log_stats,
    )

def get_output(self) -> List[EngineCoreOutput]:

Expand Down Expand Up @@ -209,9 +219,16 @@ def profile(self, is_start: bool = True) -> None:
class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore."""

def __init__(self, vllm_config: VllmConfig,
executor_class: Type[Executor]):
super().__init__(vllm_config, executor_class, asyncio_mode=True)
def __init__(
    self,
    vllm_config: VllmConfig,
    executor_class: Type[Executor],
    log_stats: bool = False,
):
    """Construct the asyncio-compatible multi-process client.

    Forwards all arguments to MPClient.__init__ with
    asyncio_mode=True, so the underlying ZMQ plumbing is set up
    for use from an asyncio event loop.
    """
    super().__init__(
        asyncio_mode=True,
        vllm_config=vllm_config,
        executor_class=executor_class,
        log_stats=log_stats,
    )

async def get_output_async(self) -> List[EngineCoreOutput]:

Expand Down
5 changes: 3 additions & 2 deletions vllm/v1/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,11 @@ def __init__(

# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
self.engine_core = EngineCoreClient.make_client(
vllm_config=vllm_config,
executor_class=executor_class,
multiprocess_mode=multiprocess_mode,
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)

@classmethod
Expand Down
Loading