diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 59d7241bd452d..aa90145705f9d 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -86,6 +86,7 @@ steps:
   - vllm/
   commands:
   - pip install -e ./plugins/vllm_add_dummy_model
+  - pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@a4987bba6e9e9b3f22bd3a6c1ecf0abd04fd5622#egg=lm_eval[api]
   - pytest -v -s entrypoints/llm
   - pytest -v -s entrypoints/openai
diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py
new file mode 100644
index 0000000000000..b442a903c33ae
--- /dev/null
+++ b/tests/entrypoints/openai/test_accuracy.py
@@ -0,0 +1,55 @@
+"""
+This file tests the accuracy of the vLLM server via LMEval.
+It uses local-completions, which interacts with vLLM
+through the OAI API with N concurrent connections.
+This simulates real-world usage of the API and makes
+sure that the zmq frontend mp RPC message passing and
+AsyncLLMEngine are working correctly.
+"""
+
+import lm_eval
+import pytest
+
+from ...utils import RemoteOpenAIServer
+
+MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
+NUM_CONCURRENT = 500
+TASK = "gsm8k"
+FILTER = "exact_match,strict-match"
+RTOL = 0.03
+EXPECTED_VALUE = 0.58
+
+
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        "--max-model-len", "4096", "--enable-chunked-prefill",
+        "--disable-log-requests", "--enforce-eager"
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def server_data(server):
+    return {
+        "url": f"{server.url_for('v1')}/completions",
+    }
+
+
+def test_lm_eval_accuracy(server_data):
+    model_args = (f"model={MODEL_NAME},"
+                  f"base_url={server_data['url']},"
+                  f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
+
+    results = lm_eval.simple_evaluate(
+        model="local-completions",
+        model_args=model_args,
+        tasks=TASK,
+    )
+
+    measured_value = results["results"][TASK][FILTER]
+    assert (measured_value - RTOL < EXPECTED_VALUE
+            and measured_value + RTOL > EXPECTED_VALUE
+            ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index ceda0b83a2397..9911cc9bdd84f 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -766,6 +766,11 @@ def is_stopped(self) -> bool:
     def errored(self) -> bool:
         return self._errored_with is not None
 
+    @property
+    def limit_concurrency(self) -> Optional[int]:
+        """Maximum number of concurrently running requests."""
+        return None
+
     def set_errored(self, exc: Exception) -> None:
         self._errored_with = exc
 
diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py
index cb16775a1cd59..6c7fd96a7f8e5 100644
--- a/vllm/engine/protocol.py
+++ b/vllm/engine/protocol.py
@@ -29,6 +29,10 @@ def is_stopped(self) -> bool:
     def errored(self) -> bool:
         ...
 
+    @property
+    def limit_concurrency(self) -> Optional[int]:
+        """Maximum number of concurrently running requests."""
+
     def generate(
         self,
         inputs: PromptInputs,
diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py
index f4a9c61a431c1..3598872b65bb0 100644
--- a/vllm/entrypoints/launcher.py
+++ b/vllm/entrypoints/launcher.py
@@ -27,6 +27,15 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
         logger.info("Route: %s, Methods: %s", path, ', '.join(methods))
 
+    # Set concurrency limits in uvicorn if running in multiprocessing mode,
+    # since zmq has a maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
+    if engine.limit_concurrency is not None:
+        logger.info(
+            "Launching Uvicorn with --limit_concurrency %s. To avoid this "
+            "limit at the expense of performance run with "
+            "--disable-frontend-multiprocessing", engine.limit_concurrency)
+        uvicorn_kwargs["limit_concurrency"] = engine.limit_concurrency
+
     config = uvicorn.Config(app, **uvicorn_kwargs)
     server = uvicorn.Server(config)
 
     _add_shutdown_handlers(app, server, engine)
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index f37c7f4d91d57..266bf79dcdd65 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -135,6 +135,12 @@ async def build_async_engine_client(
         logger.info("Multiprocessing frontend to use %s for RPC Path.",
                     rpc_path)
 
+        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
+        # NOTE: Actually, this is not true yet. We still need to support
+        # embedding models via RPC (see TODO above)
+        rpc_client = AsyncEngineRPCClient(rpc_path)
+        async_engine_client = rpc_client  # type: ignore
+
         # Start RPCServer in separate process (holds the AsyncLLMEngine).
         context = multiprocessing.get_context("spawn")
 
         # the current process might have CUDA context,
@@ -145,11 +151,6 @@ async def build_async_engine_client(
         rpc_server_process.start()
         logger.info("Started engine process with PID %d",
                     rpc_server_process.pid)
 
-        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
-        # NOTE: Actually, this is not true yet. We still need to support
-        # embedding models via RPC (see TODO above)
-        rpc_client = AsyncEngineRPCClient(rpc_path)
-        async_engine_client = rpc_client  # type: ignore
 
         try:
             while True:
diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py
index 8a7b12201cab7..981dfbfc6670e 100644
--- a/vllm/entrypoints/openai/rpc/__init__.py
+++ b/vllm/entrypoints/openai/rpc/__init__.py
@@ -7,8 +7,18 @@
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 
+# Success string used for RPC instructions.
 VLLM_RPC_SUCCESS_STR = "SUCCESS"
-VLLM_RPC_HEALTHY_STR = "HEALTHY"
+
+# Timeouts.
+VLLM_RPC_SERVER_START_TIMEOUT_MS = 1000
+VLLM_RPC_HEALTH_TIMEOUT_MS = 10000
+
+# Minimum value of ZMQ.SOCKET_LIMIT to run mp.
+VLLM_RPC_SOCKET_LIMIT_CUTOFF = 2000
+
+# HWM of 0 means an unbounded (infinite) high-water mark.
+VLLM_RPC_ZMQ_HWM = 0
 
 
 @dataclass
@@ -34,7 +44,7 @@ class RPCUtilityRequest(Enum):
     GET_SCHEDULER_CONFIG = 5
     GET_LORA_CONFIG = 6
     DO_LOG_STATS = 7
-    CHECK_HEALTH = 8
+    IS_SERVER_HEALTHY = 8
     IS_TRACING_ENABLED = 9
 
 
diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py
index 64a20b33d8f3e..7e360d1defb10 100644
--- a/vllm/entrypoints/openai/rpc/client.py
+++ b/vllm/entrypoints/openai/rpc/client.py
@@ -1,5 +1,7 @@
+import asyncio
 from contextlib import contextmanager
 from typing import Any, AsyncGenerator, Mapping, Optional
+from uuid import uuid4
 
 import cloudpickle
 import zmq
@@ -7,32 +9,140 @@
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig)
+# yapf: disable
 from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
-                                         VLLM_RPC_HEALTHY_STR,
-                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
+                                         VLLM_RPC_HEALTH_TIMEOUT_MS,
+                                         VLLM_RPC_SERVER_START_TIMEOUT_MS,
+                                         VLLM_RPC_SOCKET_LIMIT_CUTOFF,
+                                         VLLM_RPC_SUCCESS_STR,
+                                         VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
                                          RPCGenerateRequest, RPCUtilityRequest)
+# yapf: enable
 from vllm.inputs import PromptInputs
+from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 
-# Time to wait before checking it the server process is alive.
-SERVER_START_TIMEOUT_MS = 1000
+logger = init_logger(__name__)
+
+# Path used for inprocess proxy.
+INPROC_PROXY_PATH = f"inproc://{uuid4()}"
 
 
 class AsyncEngineRPCClient:
+    """
+    RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
+
+    The overall design mirrors the Asynchronous Client Server Pattern
+    https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern
+
+    On startup, the RPCClient:
+        - makes a DEALER socket (to_rpc_server) that connects to the RPCServer
+          via ipc, which uses unix sockets under the hood
+          (https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html)
+        - makes a ROUTER socket (from_api_server) that binds to a random
+          inproc address, which uses memory under the hood
+          (https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html)
+        - runs a proxy in a background asyncio task between
+          from_api_server (ROUTER, inproc) and to_rpc_server (DEALER, ipc)
+
+    Each request handled by the asyncio api_server calls generate():
+        - makes a DEALER socket that connects to from_api_server via inproc
+        - sends an RPCGenerateRequest to the inproc socket
+        - the background proxy forwards the request from inproc -> ipc
+        - the RPCServer responds to the request one token at a time over ipc
+        - the background proxy forwards the response from ipc -> inproc
+
+    The connection looks like this:
+        DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER
+
+    Message routing is performed via identities that are managed by the
+    ROUTER socket. A ROUTER socket tracks every connection it has and
+    tells the caller about them by prepending the connection identity
+    to each message it receives. When we send a message via a ROUTER,
+    we first send an identity frame.
+    See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope
+    for more details on connection identities.
+
+    This proxy design enables us to use a single unix socket, which
+    improves performance by avoiding syscalls (~5%) and sidesteps resource
+    limits such as ulimit, which defaults to 1024 on Ubuntu.
+
+    Note: we run set_hwm(0) on each socket, which sets the HWM to inf,
+    which is required to avoid dropping messages under high load.
+    This is generally not advisable. However, since we are in control
+    of both sides of the connection, a failure on either side is already
+    catastrophic to overall system health, and memory profiling suggests
+    limited memory overhead relative to asyncio, so we will proceed
+    for now.
+
+    See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks
+    for more details on high water marks.
+    """
 
     def __init__(self, rpc_path: str):
         self.context = zmq.asyncio.Context()
-        self.rpc_path = rpc_path
+
+        # Maximum number of sockets that can be opened (typically 65536).
+        # ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
+        socket_limit = self.context.get(zmq.constants.SOCKET_LIMIT)
+        if socket_limit < VLLM_RPC_SOCKET_LIMIT_CUTOFF:
+            raise ValueError(
+                f"Found zmq.constants.SOCKET_LIMIT={socket_limit}, which caps "
+                "the number of concurrent requests vLLM can process. Launch "
+                "vLLM with --disable-frontend-multiprocessing and open a "
+                "GitHub issue so we can investigate.")
+
+        # We only have 1 ipc connection that uses unix sockets, so it is
+        # safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. we will
+        # not run into ulimit issues).
+        self.context.set(zmq.constants.MAX_SOCKETS, socket_limit)
+
+        # IPC connection to RPC Server (uses unix sockets).
+        self.to_rpc_server = self.context.socket(zmq.constants.DEALER)
+        self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM)
+        self.to_rpc_server.bind(rpc_path)
+
+        # In-process proxy to RPC Server (uses memory-based messaging).
+        self.from_api_server = self.context.socket(zmq.constants.ROUTER)
+        self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM)
+        self.from_api_server.bind(INPROC_PROXY_PATH)
+
+        # Asyncio background task for the proxy.
+        self.proxy_task = asyncio.create_task(
+            self.run_proxy(self.from_api_server, self.to_rpc_server))
+
+        # Since we open 1 inproc socket per request, we have a hard cap on
+        # the number of requests that can run in vLLM w. frontend
+        # multiprocessing. This value is used by uvicorn to launch
+        # with --limit-concurrency to return 503 when the server is overloaded.
+        # We need 2 sockets per request - 2:
+        # 1 for generate(), 1 for abort(), do_log_stats(), check_health()
+        self.limit_concurrency = socket_limit // 2 - 2
+
+    async def run_proxy(self, socket_from, socket_to):
+        """Background task that runs a proxy"""
+        poller = zmq.asyncio.Poller()
+        poller.register(socket_from, zmq.constants.POLLIN)
+        poller.register(socket_to, zmq.constants.POLLIN)
+        while True:
+            events = await poller.poll()
+            events = dict(events)
+            if socket_from in events:
+                identity, msg = await socket_from.recv_multipart()
+                await socket_to.send_multipart([identity, msg])
+            if socket_to in events:
+                identity, msg = await socket_to.recv_multipart()
+                await socket_from.send_multipart([identity, msg])
 
     async def setup(self):
         """Setup the client before it starts sending server requests."""
 
         # Wait until server is ready.
-        await self.wait_for_server()
+        await self._wait_for_server_rpc()
 
         self._errored = False
 
         # Get the configs.
@@ -51,29 +161,23 @@ async def setup(self):
 
     def close(self):
         """Destroy the ZeroMQ Context."""
+        # Close all sockets associated with this context and
+        # then terminate the context.
+        self.from_api_server.close()
+        self.to_rpc_server.close()
         self.context.destroy()
 
     @contextmanager
-    def socket(self):
-        # Ensure client sockets are always closed after use
-
-        # Connect to RPC socket for Request-Reply pattern,
+    def to_proxy_socket(self):
+        # Connect to the RPCServer via the proxy.
+
         # Note that we use DEALER to enable asynchronous communication
         # to enable streaming.
         socket = self.context.socket(zmq.constants.DEALER)
+        socket.set_hwm(VLLM_RPC_ZMQ_HWM)
         try:
-            socket.connect(self.rpc_path)
+            socket.connect(INPROC_PROXY_PATH)
             yield socket
         finally:
-            # linger == 0 means discard unsent messages
-            # when the socket is closed. This is necessary
-            # because otherwise self.context.destroy() will
-            # wait for 30 seconds until unsent messages are
-            # received, which is impossible if the server
-            # crashed. In the absence of a server crash we
-            # always expect a response before closing the
-            # socket anyway.
-            # Reference: http://api.zeromq.org/4-2:zmq-setsockopt#toc24
             socket.close(linger=0)
 
     async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
@@ -81,10 +185,9 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
                                          error_message: str) -> Any:
         """Send an RPC request that is expecting data back."""
 
-        with self.socket() as socket:
-
+        with self.to_proxy_socket() as socket:
             # Ping RPCServer with a request.
-            await socket.send(cloudpickle.dumps(request))
+            await socket.send_multipart([cloudpickle.dumps(request)])
 
             # Await the data from the Server.
             data = cloudpickle.loads(await socket.recv())
@@ -93,31 +196,48 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
             # LoRAConfig can be None.
             if expected_type == LoRAConfig and data is None:
                 pass
+            elif isinstance(data, Exception):
+                logger.error(error_message)
+                raise data
             else:
                 raise ValueError(error_message)
 
         return data
 
-    async def _send_one_way_rpc_request(self,
-                                        request: RPC_REQUEST_TYPE,
-                                        error_message: str,
-                                        timeout: Optional[int] = None):
+    async def _send_one_way_rpc_request(
+            self,
+            request: RPC_REQUEST_TYPE,
+            error_message: str,
+            timeout: Optional[int] = None,
+            socket: Optional[zmq.asyncio.Socket] = None):
         """Send one-way RPC request to trigger an action."""
 
-        with self.socket() as socket:
-            # Ping RPC Server with request.
-            await socket.send(cloudpickle.dumps(request))
-
-            # Await acknowledgement from RPCServer.
+        async def do_rpc_call(socket: zmq.asyncio.Socket,
+                              request: RPC_REQUEST_TYPE,
+                              timeout=None):
+
+            await socket.send_multipart([cloudpickle.dumps(request)])
+
             if timeout is not None and await socket.poll(timeout=timeout) == 0:
-                raise TimeoutError(f"server didn't reply within {timeout} ms")
+                raise TimeoutError(f"Server didn't reply within {timeout} ms")
+
+            return cloudpickle.loads(await socket.recv())
 
-            response = cloudpickle.loads(await socket.recv())
+        # Make a new socket connection.
+        if socket is None:
+            with self.to_proxy_socket() as socket:
+                response = await do_rpc_call(socket, request, timeout)
+
+        # Use existing socket connection.
+        else:
+            response = await do_rpc_call(socket, request, timeout)
 
         if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR:
+            if isinstance(response, Exception):
+                logger.error(error_message)
+                raise response
             raise ValueError(error_message)
 
-        return response
-
     async def get_tokenizer(self, lora_request: LoRARequest):
         return await self.tokenizer.get_lora_tokenizer_async(lora_request)
 
@@ -130,13 +250,13 @@ async def get_model_config(self) -> ModelConfig:
     async def is_tracing_enabled(self) -> bool:
         return self.tracing_flag
 
-    async def wait_for_server(self):
+    async def _wait_for_server_rpc(self):
         """Wait for the RPCServer to start up."""
 
         await self._send_one_way_rpc_request(
             request=RPCUtilityRequest.IS_SERVER_READY,
-            error_message="Unable to start RPC Server.",
-            timeout=SERVER_START_TIMEOUT_MS)
+            error_message="Unable to start RPC Server",
+            timeout=VLLM_RPC_SERVER_START_TIMEOUT_MS)
 
     async def _get_model_config_rpc(self) -> ModelConfig:
         """Get the ModelConfig object from the RPC Server"""
@@ -184,8 +304,7 @@ async def _is_tracing_enabled_rpc(self) -> bool:
         return await self._send_get_data_rpc_request(
             RPCUtilityRequest.IS_TRACING_ENABLED,
             expected_type=bool,
-            error_message="Could not get is_tracing_enabled flag from RPC "
-            "Server")
+            error_message="Could not get is_tracing_enabled from RPC Server")
 
     async def abort(self, request_id: str):
         """Send an ABORT_REQUEST signal to the RPC Server"""
@@ -226,8 +345,7 @@ async def generate(
 
         finished = False
         try:
-            with self.socket() as socket:
-
+            with self.to_proxy_socket() as socket:
                 # Send RPCGenerateRequest to the RPCServer.
                 await socket.send_multipart([
                     cloudpickle.dumps(
@@ -246,43 +364,37 @@ async def generate(
                     request_output = cloudpickle.loads(message)
 
                     if isinstance(request_output, Exception):
-                        # On exception, check if the server is still healthy.
-                        # Use this to set the sync `is_running` and `errored`
-                        # properties.
-                        try:
-                            await self.check_health()
-                        except Exception:
-                            self._errored = True
+                        # On exception, check if the server is still healthy,
+                        # possibly setting the `errored` property.
+                        if not self._errored:
+                            try:
+                                await self.check_health(socket=socket)
+                            except Exception as e:
+                                self._errored = True
+                                logger.exception(repr(e))
+
                         # NB: do before raising here so that the flag is set
                         # by the time the caller receives this exception
                         raise request_output
 
                     finished = request_output.finished
                     yield request_output
+
         finally:
-            if not finished:
+            # Request was canceled by the client.
+            if not finished and not self._errored:
                 await self.abort(request_id)
 
-    async def check_health(self) -> None:
+    async def check_health(self,
+                           socket: Optional[zmq.asyncio.Socket] = None
+                           ) -> None:
         """Raise if unhealthy"""
 
-        with self.socket() as socket:
-
-            # Ping RPCServer with CHECK_HEALTH request.
-            await socket.send(cloudpickle.dumps(RPCUtilityRequest.CHECK_HEALTH)
-                              )
-
-            # Await the reply from the server.
-            # TODO: do we need an internal timeout here?
-            # Or do we expect the external probe to timeout and let this chill?
-            health_message = cloudpickle.loads(await socket.recv())
-
-            if isinstance(health_message, Exception):
-                raise health_message
-
-            if health_message != VLLM_RPC_HEALTHY_STR:
-                raise ValueError("Expected healthy response from backend but got "
-                                 "f{health_message}")
+        await self._send_one_way_rpc_request(
+            request=RPCUtilityRequest.IS_SERVER_HEALTHY,
+            error_message="Got Unhealthy response from RPC Server",
+            timeout=VLLM_RPC_HEALTH_TIMEOUT_MS,
+            socket=socket)
 
     async def encode(self, *args,
                      **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]:
diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py
index 770ee77926df9..580b83277cfbe 100644
--- a/vllm/entrypoints/openai/rpc/server.py
+++ b/vllm/entrypoints/openai/rpc/server.py
@@ -1,6 +1,6 @@
 import asyncio
 import signal
-from typing import Any, Coroutine
+from typing import Any, Coroutine, Union
 
 import cloudpickle
 import uvloop
@@ -9,14 +9,19 @@
 from typing_extensions import Never
 
 from vllm import AsyncEngineArgs, AsyncLLMEngine
-from vllm.entrypoints.openai.rpc import (VLLM_RPC_HEALTHY_STR,
-                                         VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
+from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig)
+from vllm.entrypoints.openai.rpc import (VLLM_RPC_SUCCESS_STR,
+                                         VLLM_RPC_ZMQ_HWM, RPCAbortRequest,
                                          RPCGenerateRequest, RPCUtilityRequest)
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 
 logger = init_logger(__name__)
 
+CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig,
+                    SchedulerConfig, LoRAConfig]
+
 
 class AsyncEngineRPCServer:
 
@@ -29,9 +34,10 @@ def __init__(self, async_engine_args: AsyncEngineArgs,
         # Initialize context.
         self.context = zmq.asyncio.Context()
 
-        # Init socket for readiness state.
-        self.socket = self.context.socket(zmq.constants.ROUTER)
-        self.socket.bind(rpc_path)
+        # Init socket.
+        self.socket = self.context.socket(zmq.constants.DEALER)
+        self.socket.set_hwm(VLLM_RPC_ZMQ_HWM)
+        self.socket.connect(rpc_path)
 
     def cleanup(self):
         """Cleanup all resources."""
@@ -41,39 +47,27 @@ def cleanup(self):
         # Clear the engine reference so that it can be GC'ed.
         del self.engine
 
-    async def get_model_config(self, identity):
-        """Send the ModelConfig"""
-        model_config = await self.engine.get_model_config()
-
-        await self.socket.send_multipart(
-            [identity, cloudpickle.dumps(model_config)])
-
-    async def get_decoding_config(self, identity):
-        """Send the DecodingConfig"""
-        decoding_config = await self.engine.get_decoding_config()
-
-        await self.socket.send_multipart(
-            [identity, cloudpickle.dumps(decoding_config)])
-
-    async def get_lora_config(self, identity):
-        lora_config = await self.engine.get_lora_config()
-
-        await self.socket.send_multipart(
-            [identity, cloudpickle.dumps(lora_config)])
-
-    async def get_scheduler_config(self, identity):
-        """Send the SchedulerConfig"""
-        parallel_config = await self.engine.get_scheduler_config()
-
-        await self.socket.send_multipart(
-            [identity, cloudpickle.dumps(parallel_config)])
+    async def get_config(self, identity, request):
+        try:
+            config: CONFIG_TYPE
+            if request == RPCUtilityRequest.GET_MODEL_CONFIG:
+                config = await self.engine.get_model_config()
+            elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
+                config = await self.engine.get_decoding_config()
+            elif request == RPCUtilityRequest.GET_LORA_CONFIG:
+                config = await self.engine.get_lora_config()
+            elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
+                config = await self.engine.get_scheduler_config()
+            elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
+                config = await self.engine.get_parallel_config()
+            else:
+                raise ValueError(f"Unknown Config Request: {request}")
 
-    async def get_parallel_config(self, identity):
-        """Send the ParallelConfig"""
-        parallel_config = await self.engine.get_parallel_config()
+            await self.socket.send_multipart(
+                [identity, cloudpickle.dumps(config)])
 
-        await self.socket.send_multipart(
-            [identity, cloudpickle.dumps(parallel_config)])
+        except Exception as e:
+            await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
 
     async def is_tracing_enabled(self, identity):
         """Send the is_tracing_enabled flag"""
@@ -86,31 +80,23 @@ async def do_log_stats(self, identity):
         """Log stats and confirm success."""
         await self.engine.do_log_stats()
 
-        await self.socket.send_multipart([
-            identity,
-            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
-        ])
+        await self.socket.send_multipart(
+            [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
 
     async def is_server_ready(self, identity):
         """Notify the client that we are ready."""
-        await self.socket.send_multipart([
-            identity,
-            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
-        ])
+        await self.socket.send_multipart(
+            [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
 
     async def abort(self, identity, request: RPCAbortRequest):
         """Abort request and notify the client of success."""
         try:
             # Abort the request in the llm engine.
             await self.engine.abort(request.request_id)
-        except Exception:
-            logger.warning("Failed to abort request %s", request.request_id)
-        finally:
-            # Send confirmation to the client.
-            await self.socket.send_multipart([
-                identity,
-                cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
-            ])
+            result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR
+        except Exception as e:
+            result = e
+        await self.socket.send_multipart([identity, cloudpickle.dumps(result)])
 
     async def generate(self, identity, generate_request: RPCGenerateRequest):
         try:
@@ -127,14 +113,14 @@ async def generate(self, identity, generate_request: RPCGenerateRequest):
                     [identity, cloudpickle.dumps(request_output)])
 
         except Exception as e:
-            ### Notify client of all failures
             await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
 
     async def check_health(self, identity):
         try:
             await self.engine.check_health()
             await self.socket.send_multipart(
-                [identity, cloudpickle.dumps(VLLM_RPC_HEALTHY_STR)])
+                [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)])
+
         except Exception as e:
             await self.socket.send_multipart([identity, cloudpickle.dumps(e)])
 
@@ -151,21 +137,19 @@ def _make_handler_coro(self, identity,
             return self.abort(identity, request)
 
         elif isinstance(request, RPCUtilityRequest):
-            if request == RPCUtilityRequest.GET_MODEL_CONFIG:
-                return self.get_model_config(identity)
-            elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
-                return self.get_parallel_config(identity)
-            elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
-                return self.get_decoding_config(identity)
-            elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
-                return self.get_scheduler_config(identity)
-            elif request == RPCUtilityRequest.GET_LORA_CONFIG:
-                return self.get_lora_config(identity)
+            if request in [
+                    RPCUtilityRequest.GET_MODEL_CONFIG,
+                    RPCUtilityRequest.GET_PARALLEL_CONFIG,
+                    RPCUtilityRequest.GET_DECODING_CONFIG,
+                    RPCUtilityRequest.GET_SCHEDULER_CONFIG,
+                    RPCUtilityRequest.GET_LORA_CONFIG
+            ]:
+                return self.get_config(identity, request)
             elif request == RPCUtilityRequest.DO_LOG_STATS:
                 return self.do_log_stats(identity)
             elif request == RPCUtilityRequest.IS_SERVER_READY:
                 return self.is_server_ready(identity)
-            elif request == RPCUtilityRequest.CHECK_HEALTH:
+            elif request == RPCUtilityRequest.IS_SERVER_HEALTHY:
                 return self.check_health(identity)
             elif request == RPCUtilityRequest.IS_TRACING_ENABLED:
                 return self.is_tracing_enabled(identity)