diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py
index a093a2b29278a..6cae76f74603d 100644
--- a/tests/async_engine/test_async_llm_engine.py
+++ b/tests/async_engine/test_async_llm_engine.py
@@ -26,6 +26,11 @@ class RequestOutput:
     finished: bool = False


+@dataclass
+class MockModelConfig:
+    use_async_output_proc = True
+
+
 class MockEngine:

     def __init__(self):
@@ -35,6 +40,7 @@ def __init__(self):
         self.request_id = None
         # Ugly, remove dependency when possible
         self.parallel_config = ParallelConfig(1, 1, False)
+        self.model_config = MockModelConfig()

     async def step_async(self, virtual_engine):
         # PP size is 1, ignore virtual engine
@@ -80,7 +86,7 @@ class MockAsyncLLMEngine(AsyncLLMEngine):

 @pytest.mark.asyncio
 async def test_new_requests_event():
-    engine = MockAsyncLLMEngine(worker_use_ray=False)
+    engine = MockAsyncLLMEngine()
     engine.start_background_loop()
     await asyncio.sleep(0.01)
     assert engine.engine.step_calls == 0
@@ -113,7 +119,7 @@ async def test_new_requests_event():
     assert engine.engine.add_request_calls == 3
     assert engine.engine.step_calls == old_step_calls + 1

-    engine = MockAsyncLLMEngine(worker_use_ray=True)
+    engine = MockAsyncLLMEngine()
     assert engine.get_model_config() is not None
     assert engine.get_tokenizer() is not None
     assert engine.get_decoding_config() is not None
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 8a07ce1c965e1..410e6ffaa2d50 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -1,8 +1,10 @@
 import asyncio
 import time
+import weakref
 from functools import partial
 from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
                     Mapping, Optional, Set, Tuple, Type, Union)
+from weakref import ReferenceType

 import vllm.envs as envs
 from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
@@ -26,6 +28,7 @@
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
+from vllm.utils import weak_bind

 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -450,9 +453,6 @@ class AsyncLLMEngine:
     method yields the outputs from the :class:`LLMEngine` to the caller.

     Args:
-        worker_use_ray: Whether to use Ray for model workers. Required for
-            distributed execution. Should be the same as
-            `parallel_config.worker_use_ray`.
         log_requests: Whether to log the requests.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
@@ -463,23 +463,22 @@ class AsyncLLMEngine:
     _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

     def __init__(self,
-                 worker_use_ray: bool,
                  *args,
                  log_requests: bool = True,
                  start_engine_loop: bool = True,
                  **kwargs) -> None:
-        self.worker_use_ray = worker_use_ray
         self.log_requests = log_requests
         self.engine = self._engine_class(*args, **kwargs)

         # This ensures quick processing of request outputs
         # so the append to asyncio queues is not delayed,
         # especially for multi-step.
-        #
-        self.use_process_request_outputs_callback = True
+        self.use_process_request_outputs_callback = (
+            self.engine.model_config.use_async_output_proc)
+
         if self.use_process_request_outputs_callback:
             self.engine.process_request_outputs_callback = \
-                self.process_request_outputs
+                weak_bind(self.process_request_outputs)

         self.background_loop: Optional[asyncio.Future] = None
         # We need to keep a reference to unshielded
@@ -492,6 +491,11 @@ def __init__(self,
         # Lazy initialized fields
         self._request_tracker: RequestTracker

+    def __del__(self):
+        if rt := getattr(self, "request_tracker", None):
+            # Wake up engine loop so that it will exit cleanly
+            rt.new_requests_event.set()
+
     @classmethod
     def _get_executor_cls(
             cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
@@ -502,15 +506,12 @@ def _get_executor_cls(
                 raise TypeError(
                     "distributed_executor_backend must be a subclass of "
                     f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
-            if distributed_executor_backend.uses_ray:  # type: ignore
-                initialize_ray_cluster(engine_config.parallel_config)
             executor_class = distributed_executor_backend
         elif engine_config.device_config.device_type == "neuron":
             from vllm.executor.neuron_executor import NeuronExecutorAsync
             executor_class = NeuronExecutorAsync
         elif engine_config.device_config.device_type == "tpu":
             if distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
                 executor_class = RayTPUExecutorAsync
             else:
@@ -531,11 +532,9 @@ def _get_executor_cls(
                 from vllm.executor.xpu_executor import XPUExecutorAsync
                 executor_class = XPUExecutorAsync
             elif distributed_executor_backend == "ray":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
                 executor_class = RayXPUExecutorAsync
             elif distributed_executor_backend == "mp":
-                initialize_ray_cluster(engine_config.parallel_config)
                 from vllm.executor.multiproc_xpu_executor import (
                     MultiprocessingXPUExecutorAsync)
                 executor_class = MultiprocessingXPUExecutorAsync
@@ -543,7 +542,6 @@ def _get_executor_cls(
                 raise RuntimeError(
                     "Not supported distributed execution model on XPU device.")
         elif distributed_executor_backend == "ray":
-            initialize_ray_cluster(engine_config.parallel_config)
             from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
             executor_class = RayGPUExecutorAsync
         elif distributed_executor_backend == "mp":
@@ -559,19 +557,23 @@ def _get_executor_cls(
     def from_engine_args(
         cls,
         engine_args: AsyncEngineArgs,
+        engine_config: Optional[EngineConfig] = None,
         start_engine_loop: bool = True,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
         stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
     ) -> "AsyncLLMEngine":
         """Creates an async LLM engine from the engine arguments."""
         # Create the engine configs.
-        engine_config = engine_args.create_engine_config()
+        if engine_config is None:
+            engine_config = engine_args.create_engine_config()

         executor_class = cls._get_executor_cls(engine_config)

+        if executor_class.uses_ray:
+            initialize_ray_cluster(engine_config.parallel_config)
+
         # Create the async LLM engine.
         engine = cls(
-            executor_class.uses_ray,
             **engine_config.to_dict(),
             executor_class=executor_class,
             log_requests=not engine_args.disable_log_requests,
@@ -628,7 +630,7 @@ def start_background_loop(self) -> None:
         self._request_tracker = RequestTracker()

         self._background_loop_unshielded = asyncio.get_event_loop(
-        ).create_task(self.run_engine_loop())
+        ).create_task(self.run_engine_loop(weakref.ref(self)))
         self._background_loop_unshielded.add_done_callback(
             partial(_log_task_completion, error_callback=self._error_callback))
         self.background_loop = asyncio.shield(self._background_loop_unshielded)
@@ -698,9 +700,16 @@ def process_request_outputs(self, request_outputs) -> bool:
     async def _engine_abort(self, request_ids: Iterable[str]):
         self.engine.abort_request(request_ids)

-    async def run_engine_loop(self):
+    @staticmethod
+    async def run_engine_loop(engine_ref: ReferenceType):
+        """We use a weakref to the engine so that the running loop
+        doesn't prevent the engine being garbage collected."""
+        engine: Optional["AsyncLLMEngine"] = engine_ref()
+        if not engine:
+            return
+
         pipeline_parallel_size = \
-            self.engine.parallel_config.pipeline_parallel_size
+                engine.engine.parallel_config.pipeline_parallel_size
         has_requests_in_progress = [False] * pipeline_parallel_size
         while True:
             if not any(has_requests_in_progress):
@@ -711,11 +720,21 @@ async def run_engine_loop(self):
                 # timeout, and unblocks the RPC thread in the workers so that
                 # they can process any other queued control plane messages,
                 # such as add/remove lora adapters.
-                await self.engine.stop_remote_worker_execution_loop_async()
-                await self._request_tracker.wait_for_new_requests()
+                await engine.engine.stop_remote_worker_execution_loop_async()
+                request_tracker = engine._request_tracker
+                # Allow engine to be garbage collected while
+                # waiting for new requests
+                del engine
+                await asyncio.sleep(0)
+                if engine_ref() is None:
+                    return
+                await request_tracker.wait_for_new_requests()
+                engine = engine_ref()
+                if not engine:
+                    return
                 logger.debug("Got new requests!")
                 requests_in_progress = [
-                    asyncio.create_task(self.engine_step(ve))
+                    asyncio.create_task(engine.engine_step(ve))
                     for ve in range(pipeline_parallel_size)
                 ]
                 has_requests_in_progress = [True] * pipeline_parallel_size
@@ -733,19 +752,20 @@ async def run_engine_loop(self):
                     result = task.result()
                     virtual_engine = requests_in_progress.index(task)
                     has_unfinished_requests = (
-                        self.engine.has_unfinished_requests_for_virtual_engine(
+                        engine.engine.
+                        has_unfinished_requests_for_virtual_engine(
                             virtual_engine))
                     if result or has_unfinished_requests:
                         requests_in_progress[virtual_engine] = (
                             asyncio.create_task(
-                                self.engine_step(virtual_engine)))
+                                engine.engine_step(virtual_engine)))
                         has_requests_in_progress[virtual_engine] = True
                     else:
                         has_requests_in_progress[virtual_engine] = False
                 except asyncio.TimeoutError as exc:
                     logger.error(
                         "Engine iteration timed out. This should never happen!")
-                    self.set_errored(exc)
+                    engine.set_errored(exc)
                     raise

             await asyncio.sleep(0)
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index dfdbc22ef00e1..8b5009b2c6668 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1,8 +1,8 @@
-import functools
 import time
 from collections import deque
 from contextlib import contextmanager
 from dataclasses import dataclass
+from functools import partial
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
                     Iterable, List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
@@ -51,7 +51,7 @@
     BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
-from vllm.utils import Counter, Device
+from vllm.utils import Counter, Device, weak_bind
 from vllm.version import __version__ as VLLM_VERSION

 logger = init_logger(__name__)
@@ -382,11 +382,16 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
             for _ in range(self.parallel_config.pipeline_parallel_size)
         ]

-        self.async_callbacks = [
-            functools.partial(self._process_model_outputs,
-                              ctx=self.scheduler_contexts[v_id])
-            for v_id in range(self.parallel_config.pipeline_parallel_size)
-        ]
+        if model_config.use_async_output_proc:
+            process_model_outputs = weak_bind(self._process_model_outputs)
+
+            self.async_callbacks = [
+                partial(process_model_outputs,
+                        ctx=self.scheduler_contexts[v_id])
+                for v_id in range(self.parallel_config.pipeline_parallel_size)
+            ]
+        else:
+            self.async_callbacks = []

         # Currently used by AsyncLLMEngine to ensure quick append
         # of request outputs to asyncio queues
@@ -869,8 +874,8 @@ def has_unfinished_requests_for_virtual_engine(
         """
         return self.scheduler[virtual_engine].has_unfinished_seqs()

+    @staticmethod
     def _process_sequence_group_outputs(
-        self,
         seq_group: SequenceGroup,
         outputs: List[EmbeddingSequenceGroupOutput],
     ) -> None:
diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py
index 3598872b65bb0..47d227010c075 100644
--- a/vllm/entrypoints/launcher.py
+++ b/vllm/entrypoints/launcher.py
@@ -1,21 +1,20 @@
 import asyncio
 import signal
 from http import HTTPStatus
-from typing import Any
+from typing import Any, Optional

 import uvicorn
-from fastapi import FastAPI, Response
+from fastapi import FastAPI, Request, Response

 from vllm import envs
 from vllm.engine.async_llm_engine import AsyncEngineDeadError
-from vllm.engine.protocol import AsyncEngineClient
 from vllm.logger import init_logger
 from vllm.utils import find_process_using_port

 logger = init_logger(__name__)


-async def serve_http(app: FastAPI, engine: AsyncEngineClient,
+async def serve_http(app: FastAPI, limit_concurrency: Optional[int],
                      **uvicorn_kwargs: Any):
     logger.info("Available routes are:")
     for route in app.routes:
@@ -29,16 +28,16 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,

     # Set concurrency limits in uvicorn if running in multiprocessing mode
     # since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
-    if engine.limit_concurrency is not None:
+    if limit_concurrency is not None:
         logger.info(
             "Launching Uvicorn with --limit_concurrency %s. To avoid this "
             "limit at the expense of performance run with "
-            "--disable-frontend-multiprocessing", engine.limit_concurrency)
-        uvicorn_kwargs["limit_concurrency"] = engine.limit_concurrency
+            "--disable-frontend-multiprocessing", limit_concurrency)
+        uvicorn_kwargs["limit_concurrency"] = limit_concurrency

     config = uvicorn.Config(app, **uvicorn_kwargs)
     server = uvicorn.Server(config)
-    _add_shutdown_handlers(app, server, engine)
+    _add_shutdown_handlers(app, server)

     loop = asyncio.get_running_loop()

@@ -68,15 +67,15 @@ async def dummy_shutdown() -> None:
             return server.shutdown()


-def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
-                           engine: AsyncEngineClient) -> None:
+def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
     """Adds handlers for fatal errors that should crash the server"""

     @app.exception_handler(RuntimeError)
-    async def runtime_error_handler(_, __):
+    async def runtime_error_handler(request: Request, __):
         """On generic runtime error, check to see if the engine has died.
         It probably has, in which case the server will no longer be able
         to handle requests. Trigger a graceful shutdown with a SIGTERM."""
+        engine = request.app.state.engine_client
         if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
                 and not engine.is_running):
             logger.fatal("AsyncLLMEngine has failed, terminating server "
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 7c1f307e06619..b50fc6a265f8d 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -4,16 +4,20 @@
 import multiprocessing
 import os
 import re
+import signal
 import tempfile
 from argparse import Namespace
 from contextlib import asynccontextmanager
+from functools import partial
 from http import HTTPStatus
 from typing import AsyncIterator, Optional, Set

+import uvloop
 from fastapi import APIRouter, FastAPI, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
+from starlette.datastructures import State
 from starlette.routing import Mount
 from typing_extensions import assert_never

@@ -54,12 +58,6 @@

 TIMEOUT_KEEP_ALIVE = 5  # seconds

-async_engine_client: AsyncEngineClient
-engine_args: AsyncEngineArgs
-openai_serving_chat: OpenAIServingChat
-openai_serving_completion: OpenAIServingCompletion
-openai_serving_embedding: OpenAIServingEmbedding
-openai_serving_tokenization: OpenAIServingTokenization
 prometheus_multiproc_dir: tempfile.TemporaryDirectory

 # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
@@ -83,18 +81,28 @@ def model_is_embedding(model_name: str, trust_remote_code: bool,

 @asynccontextmanager
 async def lifespan(app: FastAPI):
-
-    async def _force_log():
-        while True:
-            await asyncio.sleep(10)
-            await async_engine_client.do_log_stats()
-
-    if not engine_args.disable_log_stats:
-        task = asyncio.create_task(_force_log())
-        _running_tasks.add(task)
-        task.add_done_callback(_running_tasks.remove)
-
-    yield
+    try:
+        if app.state.log_stats:
+            async_engine_client = app.state.engine_client
+
+            async def _force_log():
+                while True:
+                    await asyncio.sleep(10)
+                    await async_engine_client.do_log_stats()
+
+            task = asyncio.create_task(_force_log())
+            _running_tasks.add(task)
+            task.add_done_callback(_running_tasks.remove)
+        else:
+            task = None
+        try:
+            yield
+        finally:
+            if task is not None:
+                task.cancel()
+    finally:
+        # Ensure app state including engine ref is gc'd
+        del app.state


 @asynccontextmanager
@@ -103,16 +111,10 @@ async def build_async_engine_client(
     # Context manager to handle async_engine_client lifecycle
     # Ensures everything is shutdown and cleaned up on error/exit

-    global engine_args
     engine_args = AsyncEngineArgs.from_cli_args(args)

-    # Backend itself still global for the silly lil' health handler
-    global async_engine_client
-
     async with build_async_engine_client_from_engine_args(
             engine_args, args.disable_frontend_multiprocessing) as engine:
-
-        async_engine_client = engine  # type: ignore[assignment]

         yield engine

@@ -134,12 +136,22 @@ async def build_async_engine_client_from_engine_args(
     if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
                            engine_args.quantization, engine_args.revision)
             or disable_frontend_multiprocessing):
-        engine_client = AsyncLLMEngine.from_engine_args(
-            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
-        try:
-            yield engine_client
-        finally:
-            engine_client.shutdown_background_loop()
+        engine_config = engine_args.create_engine_config()
+        uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
+                           "uses_ray", False)
+
+        build_engine = partial(AsyncLLMEngine.from_engine_args,
+                               engine_args=engine_args,
+                               engine_config=engine_config,
+                               usage_context=UsageContext.OPENAI_API_SERVER)
+        if uses_ray:
+            # Must run in main thread with ray for its signal handlers to work
+            engine_client = build_engine()
+        else:
+            engine_client = await asyncio.get_running_loop().run_in_executor(
+                None, build_engine)
+
+        yield engine_client
         return

     # Otherwise, use the multiprocessing AsyncLLMEngine.
@@ -241,16 +253,36 @@ def mount_metrics(app: FastAPI):
     app.routes.append(metrics_route)


+def chat(request: Request) -> OpenAIServingChat:
+    return request.app.state.openai_serving_chat
+
+
+def completion(request: Request) -> OpenAIServingCompletion:
+    return request.app.state.openai_serving_completion
+
+
+def tokenization(request: Request) -> OpenAIServingTokenization:
+    return request.app.state.openai_serving_tokenization
+
+
+def embedding(request: Request) -> OpenAIServingEmbedding:
+    return request.app.state.openai_serving_embedding
+
+
+def engine_client(request: Request) -> AsyncEngineClient:
+    return request.app.state.engine_client
+
+
 @router.get("/health")
-async def health() -> Response:
+async def health(raw_request: Request) -> Response:
     """Health check."""
-    await async_engine_client.check_health()
+    await engine_client(raw_request).check_health()
     return Response(status_code=200)


 @router.post("/tokenize")
-async def tokenize(request: TokenizeRequest):
-    generator = await openai_serving_tokenization.create_tokenize(request)
+async def tokenize(request: TokenizeRequest, raw_request: Request):
+    generator = await tokenization(raw_request).create_tokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -261,8 +293,8 @@ async def tokenize(request: TokenizeRequest):


 @router.post("/detokenize")
-async def detokenize(request: DetokenizeRequest):
-    generator = await openai_serving_tokenization.create_detokenize(request)
+async def detokenize(request: DetokenizeRequest, raw_request: Request):
+    generator = await tokenization(raw_request).create_detokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -273,8 +305,8 @@ async def detokenize(request: DetokenizeRequest):


 @router.get("/v1/models")
-async def show_available_models():
-    models = await openai_serving_completion.show_available_models()
+async def show_available_models(raw_request: Request):
+    models = await completion(raw_request).show_available_models()
     return JSONResponse(content=models.model_dump())


@@ -288,7 +320,7 @@ async def show_version():

 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
-    generator = await openai_serving_chat.create_chat_completion(
+    generator = await chat(raw_request).create_chat_completion(
         request, raw_request)

     if isinstance(generator, ErrorResponse):
@@ -303,7 +335,7 @@ async def create_chat_completion(request: ChatCompletionRequest,

 @router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
-    generator = await openai_serving_completion.create_completion(
+    generator = await completion(raw_request).create_completion(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
@@ -316,7 +348,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):

 @router.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
-    generator = await openai_serving_embedding.create_embedding(
+    generator = await embedding(raw_request).create_embedding(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
@@ -333,16 +365,16 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
                 "used for local development!")

     @router.post("/start_profile")
-    async def start_profile():
+    async def start_profile(raw_request: Request):
         logger.info("Starting profiler...")
-        await async_engine_client.start_profile()
+        await engine_client(raw_request).start_profile()
         logger.info("Profiler started.")
         return Response(status_code=200)

     @router.post("/stop_profile")
-    async def stop_profile():
+    async def stop_profile(raw_request: Request):
         logger.info("Stopping profiler...")
-        await async_engine_client.stop_profile()
+        await engine_client(raw_request).stop_profile()
         logger.info("Profiler stopped.")
         return Response(status_code=200)

@@ -353,13 +385,14 @@ async def stop_profile():
                 "This should ONLY be used for local development!")

     @router.post("/v1/load_lora_adapter")
-    async def load_lora_adapter(request: LoadLoraAdapterRequest):
-        response = await openai_serving_chat.load_lora_adapter(request)
+    async def load_lora_adapter(request: LoadLoraAdapterRequest,
+                                raw_request: Request):
+        response = await chat(raw_request).load_lora_adapter(request)
         if isinstance(response, ErrorResponse):
             return JSONResponse(content=response.model_dump(),
                                 status_code=response.code)

-        response = await openai_serving_completion.load_lora_adapter(request)
+        response = await completion(raw_request).load_lora_adapter(request)
         if isinstance(response, ErrorResponse):
             return JSONResponse(content=response.model_dump(),
                                 status_code=response.code)
@@ -367,13 +400,14 @@ async def load_lora_adapter(request: LoadLoraAdapterRequest):
         return Response(status_code=200, content=response)

     @router.post("/v1/unload_lora_adapter")
-    async def unload_lora_adapter(request: UnloadLoraAdapterRequest):
-        response = await openai_serving_chat.unload_lora_adapter(request)
+    async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
+                                  raw_request: Request):
+        response = await chat(raw_request).unload_lora_adapter(request)
         if isinstance(response, ErrorResponse):
             return JSONResponse(content=response.model_dump(),
                                 status_code=response.code)

-        response = await openai_serving_completion.unload_lora_adapter(request)
+        response = await completion(raw_request).unload_lora_adapter(request)
         if isinstance(response, ErrorResponse):
             return JSONResponse(content=response.model_dump(),
                                 status_code=response.code)
@@ -398,7 +432,8 @@ def build_app(args: Namespace) -> FastAPI:

     @app.exception_handler(RequestValidationError)
     async def validation_exception_handler(_, exc):
-        err = openai_serving_chat.create_error_response(message=str(exc))
+        chat = app.state.openai_serving_chat
+        err = chat.create_error_response(message=str(exc))
         return JSONResponse(err.model_dump(),
                             status_code=HTTPStatus.BAD_REQUEST)

@@ -430,30 +465,26 @@ async def authentication(request: Request, call_next):
     return app


-async def init_app(
+def init_app_state(
     async_engine_client: AsyncEngineClient,
+    model_config: ModelConfig,
+    state: State,
     args: Namespace,
-) -> FastAPI:
-    app = build_app(args)
-
+) -> None:
     if args.served_model_name is not None:
         served_model_names = args.served_model_name
     else:
         served_model_names = [args.model]

-    model_config = await async_engine_client.get_model_config()
-
     if args.disable_log_requests:
         request_logger = None
     else:
         request_logger = RequestLogger(max_log_len=args.max_log_len)

-    global openai_serving_chat
-    global openai_serving_completion
-    global openai_serving_embedding
-    global openai_serving_tokenization
+    state.engine_client = async_engine_client
+    state.log_stats = not args.disable_log_stats

-    openai_serving_chat = OpenAIServingChat(
+    state.openai_serving_chat = OpenAIServingChat(
         async_engine_client,
         model_config,
         served_model_names,
@@ -465,7 +496,7 @@ async def init_app(
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
         tool_parser=args.tool_call_parser)
-    openai_serving_completion = OpenAIServingCompletion(
+    state.openai_serving_completion = OpenAIServingCompletion(
         async_engine_client,
         model_config,
         served_model_names,
@@ -474,13 +505,13 @@ async def init_app(
         request_logger=request_logger,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
     )
-    openai_serving_embedding = OpenAIServingEmbedding(
+    state.openai_serving_embedding = OpenAIServingEmbedding(
         async_engine_client,
         model_config,
         served_model_names,
         request_logger=request_logger,
     )
-    openai_serving_tokenization = OpenAIServingTokenization(
+    state.openai_serving_tokenization = OpenAIServingTokenization(
         async_engine_client,
         model_config,
         served_model_names,
@@ -488,25 +519,31 @@ async def init_app(
         request_logger=request_logger,
         chat_template=args.chat_template,
     )
-    app.root_path = args.root_path
-
-    return app


 async def run_server(args, **uvicorn_kwargs) -> None:
     logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)

+    def signal_handler(*_) -> None:
+        # Interrupt server on sigterm while initializing
+        raise KeyboardInterrupt("terminated")
+
+    signal.signal(signal.SIGTERM, signal_handler)
+
     async with build_async_engine_client(args) as async_engine_client:
         # If None, creation of the client failed and we exit.
         if async_engine_client is None:
             return

-        app = await init_app(async_engine_client, args)
+        app = build_app(args)
+
+        model_config = await async_engine_client.get_model_config()
+        init_app_state(async_engine_client, model_config, app.state, args)

         shutdown_task = await serve_http(
             app,
-            engine=async_engine_client,
+            limit_concurrency=async_engine_client.limit_concurrency,
             host=args.host,
             port=args.port,
             log_level=args.uvicorn_log_level,
@@ -530,4 +567,4 @@ async def run_server(args, **uvicorn_kwargs) -> None:
     parser = make_arg_parser(parser)
     args = parser.parse_args()

-    asyncio.run(run_server(args))
+    uvloop.run(run_server(args))
diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py
index bebc2faedb680..460ff0636b6e9 100644
--- a/vllm/entrypoints/openai/rpc/server.py
+++ b/vllm/entrypoints/openai/rpc/server.py
@@ -46,7 +46,6 @@ def cleanup(self):
         """Cleanup all resources."""
         self.socket.close()
         self.context.destroy()
-        self.engine.shutdown_background_loop()
         # Clear the engine reference so that it can be GC'ed.
         del self.engine

@@ -233,5 +232,12 @@ def signal_handler() -> None:

 def run_rpc_server(async_engine_args: AsyncEngineArgs,
                    usage_context: UsageContext, rpc_path: str):
+
+    def signal_handler(*_) -> None:
+        # Interrupt server on sigterm while initializing
+        raise KeyboardInterrupt("AsyncEngineRPCServer terminated")
+
+    signal.signal(signal.SIGTERM, signal_handler)
+
     server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path)
     uvloop.run(run_server(server))
diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py
index 9c6d4051eb3f8..cc535e99a06ef 100644
--- a/vllm/executor/multiproc_gpu_executor.py
+++ b/vllm/executor/multiproc_gpu_executor.py
@@ -1,8 +1,5 @@
 import asyncio
 import os
-import signal
-import threading
-import weakref
 from functools import partial
 from typing import Any, List, Optional

@@ -108,17 +105,6 @@ def _init_executor(self) -> None:
         # Set up signal handlers to shutdown the executor cleanly
         # sometimes gc does not work well

-        # Use weakref to avoid holding a reference to self
-        ref = weakref.ref(self)
-
-        def shutdown(signum, frame):
-            if executor := ref():
-                executor.shutdown()
-
-        if threading.current_thread() is threading.main_thread():
-            signal.signal(signal.SIGINT, shutdown)
-            signal.signal(signal.SIGTERM, shutdown)
-
         self.driver_worker = self._create_worker(
             distributed_init_method=distributed_init_method)
         self._run_workers("init_device")
diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py
index 28c8e8699f083..aa2a16c04d08d 100644
--- a/vllm/executor/multiproc_worker_utils.py
+++ b/vllm/executor/multiproc_worker_utils.py
@@ -120,7 +120,8 @@ def run(self) -> None:
                     logger.error("Worker %s pid %s died, exit code: %s",
                                  process.name, process.pid, process.exitcode)
             # Cleanup any remaining workers
-            logger.info("Killing local vLLM worker processes")
+            if logger:
+                logger.info("Killing local vLLM worker processes")
             for worker in self.workers:
                 worker.kill_worker()
             # Must be done after worker task queues are all closed
@@ -221,6 +222,8 @@ def _run_worker_process(
         try:
             executor = getattr(worker, method)
             output = executor(*args, **kwargs)
+        except KeyboardInterrupt:
+            break
         except BaseException as e:
             tb = traceback.format_exc()
             logger.error(
diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py
index 732b69d6e5954..d02fecb46f007 100644
--- a/vllm/executor/ray_tpu_executor.py
+++ b/vllm/executor/ray_tpu_executor.py
@@ -26,6 +26,8 @@ class RayTPUExecutor(TPUExecutor):

+    uses_ray: bool = True
+
     def __init__(self, *args, **kwargs):
         # This is non-None when the execute model loop is running
         # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
diff --git a/vllm/scripts.py b/vllm/scripts.py
index e557961a335bf..231a18e99f3d7 100644
--- a/vllm/scripts.py
+++ b/vllm/scripts.py
@@ -1,11 +1,11 @@
 # The CLI entrypoint to vLLM.
 import argparse
-import asyncio
 import os
 import signal
 import sys
 from typing import List, Optional

+import uvloop
 from openai import OpenAI
 from openai.types.chat import ChatCompletionMessageParam

@@ -34,7 +34,7 @@ def serve(args: argparse.Namespace) -> None:
     # EngineArgs expects the model name to be passed as --model.
     args.model = args.model_tag

-    asyncio.run(run_server(args))
+    uvloop.run(run_server(args))


 def interactive_cli(args: argparse.Namespace) -> None:
diff --git a/vllm/utils.py b/vllm/utils.py
index aba243071b69a..014fc16a17c1f 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -12,6 +12,7 @@
 import threading
 import uuid
 import warnings
+import weakref
 from asyncio import FIRST_COMPLETED, ensure_future
 from functools import lru_cache, partial, wraps
 from platform import uname
@@ -1079,6 +1080,20 @@ def cuda_device_count_stateless() -> int:
     return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)


+def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
+    """Make an instance method that weakly references
+    its associated instance and no-ops once that
+    instance is collected."""
+    ref = weakref.ref(bound_method.__self__)  # type: ignore[attr-defined]
+    unbound = bound_method.__func__  # type: ignore[attr-defined]
+
+    def weak_bound(*args, **kwargs) -> None:
+        if inst := ref():
+            unbound(inst, *args, **kwargs)
+
+    return weak_bound
+
+
 #From: https://stackoverflow.com/a/4104188/2749989
 def run_once(f: Callable[P, None]) -> Callable[P, None]:
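
Illustrative sketch (not part of the patch): how the new weak_bind helper added to vllm/utils.py is expected to behave once the diff is applied. The returned callable keeps only a weak reference to the bound instance, so it silently becomes a no-op after the instance is garbage collected; this is why the engine-loop and output-processing callbacks above can be registered without creating reference cycles. The Reporter class and the calls below are hypothetical examples; only weak_bind itself comes from the patch.

import gc

from vllm.utils import weak_bind  # helper introduced by the diff above


class Reporter:
    """Hypothetical example class, not part of vLLM."""

    def __init__(self) -> None:
        self.calls = 0

    def report(self, msg: str) -> None:
        self.calls += 1
        print(f"report #{self.calls}: {msg}")


reporter = Reporter()
callback = weak_bind(reporter.report)  # holds no strong ref to `reporter`

callback("first")    # forwards to reporter.report -> "report #1: first"
del reporter
gc.collect()
callback("second")   # instance collected: the weakly bound call is a no-op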