diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py
index fe00640c0021e..521a450f13568 100644
--- a/tests/entrypoints/openai/test_completion.py
+++ b/tests/entrypoints/openai/test_completion.py
@@ -119,6 +119,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str,
     choice = completion.choices[0]
     assert len(choice.text) >= 5
     assert choice.finish_reason == "length"
+    print(completion.usage)
     assert completion.usage == openai.types.CompletionUsage(
         completion_tokens=5,
         prompt_tokens=6 + num_virtual_tokens,
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 16b7bc64a2849..0584d8eb6f323 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -7,7 +7,8 @@
 from transformers import PreTrainedTokenizer
 
 import vllm.envs as envs
-from vllm.config import DecodingConfig, EngineConfig, ModelConfig
+from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig)
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
@@ -924,6 +925,14 @@ async def get_model_config(self) -> ModelConfig:
         else:
             return self.engine.get_model_config()
 
+    async def get_parallel_config(self) -> ParallelConfig:
+        """Get the parallel configuration of the vLLM engine."""
+        if self.engine_use_ray:
+            return await self.engine.get_parallel_config.remote(  # type: ignore
+            )
+        else:
+            return self.engine.get_parallel_config()
+
     async def get_decoding_config(self) -> DecodingConfig:
         """Get the decoding configuration of the vLLM engine."""
         if self.engine_use_ray:
@@ -932,6 +941,22 @@ async def get_decoding_config(self) -> DecodingConfig:
         else:
             return self.engine.get_decoding_config()
 
+    async def get_scheduler_config(self) -> SchedulerConfig:
+        """Get the scheduling configuration of the vLLM engine."""
+        if self.engine_use_ray:
+            return await self.engine.get_scheduler_config.remote(  # type: ignore
+            )
+        else:
+            return self.engine.get_scheduler_config()
+
+    async def get_lora_config(self) -> LoRAConfig:
+        """Get the LoRA configuration of the vLLM engine."""
+        if self.engine_use_ray:
+            return await self.engine.get_lora_config.remote(  # type: ignore
+            )
+        else:
+            return self.engine.get_lora_config()
+
     async def do_log_stats(
             self,
             scheduler_outputs: Optional[SchedulerOutputs] = None,
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 48d5305892219..627f028b99d7e 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -40,8 +40,8 @@
                           init_tracer)
 from vllm.transformers_utils.config import try_get_generation_config
 from vllm.transformers_utils.detokenizer import Detokenizer
-from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
-                                                     get_tokenizer_group)
+from vllm.transformers_utils.tokenizer_group import (
+    BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
 from vllm.utils import Counter
@@ -481,19 +481,12 @@ def get_tokenizer_for_seq(self,
         return self.get_tokenizer_group().get_lora_tokenizer(
             sequence.lora_request)
 
-    def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
-        init_kwargs = dict(
-            tokenizer_id=self.model_config.tokenizer,
-            enable_lora=bool(self.lora_config),
-            max_num_seqs=self.scheduler_config.max_num_seqs,
-            max_input_length=None,
-            tokenizer_mode=self.model_config.tokenizer_mode,
-            trust_remote_code=self.model_config.trust_remote_code,
-            revision=self.model_config.tokenizer_revision)
-        init_kwargs.update(tokenizer_init_kwargs)
-
-        return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
-                                   **init_kwargs)
+    def _init_tokenizer(self) -> BaseTokenizerGroup:
+        return init_tokenizer_from_configs(
+            model_config=self.model_config,
+            scheduler_config=self.scheduler_config,
+            parallel_config=self.parallel_config,
+            enable_lora=bool(self.lora_config))
 
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
@@ -755,10 +748,22 @@ def get_model_config(self) -> ModelConfig:
         """Gets the model configuration."""
         return self.model_config
 
+    def get_parallel_config(self) -> ParallelConfig:
+        """Gets the parallel configuration."""
+        return self.parallel_config
+
     def get_decoding_config(self) -> DecodingConfig:
         """Gets the decoding configuration."""
         return self.decoding_config
 
+    def get_scheduler_config(self) -> SchedulerConfig:
+        """Gets the scheduler configuration."""
+        return self.scheduler_config
+
+    def get_lora_config(self) -> LoRAConfig:
+        """Gets the LoRA configuration."""
+        return self.lora_config
+
     def get_num_unfinished_requests(self) -> int:
         """Gets the number of unfinished requests."""
         return sum(scheduler.get_num_unfinished_seq_groups()
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 772738351cda5..104f70f1386ad 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -16,7 +16,6 @@
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from prometheus_client import make_asgi_app
 from starlette.routing import Mount
-from transformers import AutoTokenizer
 
 import vllm.envs as envs
 from vllm.config import ModelConfig
@@ -115,11 +114,8 @@ async def build_backend(args) -> AsyncIterator[VLLMBackend]:
     rpc_server_process.start()
 
     ## Then build the client for the backend process
-    # TODO: figure out a way around passing the tokenizer
-    backend = RPCClient(tokenizer=AutoTokenizer.from_pretrained(
-        args.model),
-                        port=port)
-    await backend.wait_for_server()
+    backend = RPCClient(port)
+    await backend.setup()
 
     try:
         yield backend
diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py
index 7187bcdbe77bc..0c055b76fe2aa 100644
--- a/vllm/entrypoints/openai/rpc/__init__.py
+++ b/vllm/entrypoints/openai/rpc/__init__.py
@@ -29,8 +29,12 @@ class RPCAbortRequest:
 class RPCUtilityRequest(Enum):
     IS_SERVER_READY = 1
     GET_MODEL_CONFIG = 2
-    DO_LOG_STATS = 3
-    CHECK_HEALTH = 4
+    GET_DECODING_CONFIG = 3
+    GET_PARALLEL_CONFIG = 4
+    GET_SCHEDULER_CONFIG = 5
+    GET_LORA_CONFIG = 6
+    DO_LOG_STATS = 7
+    CHECK_HEALTH = 8
 
 
 RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest,
diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py
index 9bcdf6c48bbcc..f69e7c24b449e 100644
--- a/vllm/entrypoints/openai/rpc/client.py
+++ b/vllm/entrypoints/openai/rpc/client.py
@@ -1,10 +1,11 @@
 import pickle
-from typing import AsyncIterator, Mapping, Optional
+from typing import Any, AsyncIterator, Mapping, Optional
 
 import zmq
 import zmq.asyncio
 
-from vllm.config import DecodingConfig, ModelConfig
+from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig)
 from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE,
                                          VLLM_RPC_HEALTHY_STR,
                                          VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
@@ -14,24 +15,64 @@
 from vllm.outputs import RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
+from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 
 
 class RPCClient:
 
-    # TODO: check if opening all these sockets is an antipattern?
-    def __init__(self, tokenizer, port: int):
-        # ZMQ context.
+    def __init__(self, port: int):
         self.context = zmq.asyncio.Context()
-
-        # TODO: do the tokenizer properly.
-        self.tokenizer = tokenizer
-        self.decoding_config = DecodingConfig()
         self.path = f"tcp://localhost:{port}"
 
+    async def setup(self):
+        """Set up the client before it starts sending server requests."""
+
+        # Wait until server is ready.
+        await self.wait_for_server()
+
+        # Get the configs.
+        self.model_config = await self._get_model_config_rpc()
+        self.decoding_config = await self._get_decoding_config_rpc()
+
+        # Create the tokenizer group.
+        # TODO: refactor OAI server to avoid needing this info.
+        self.tokenizer = init_tokenizer_from_configs(
+            model_config=self.model_config,
+            scheduler_config=(await self._get_scheduler_config_rpc()),
+            parallel_config=(await self._get_parallel_config_rpc()),
+            enable_lora=bool(await self._get_lora_config_rpc()),
+        )
+
     def close(self):
         """Destroy the ZeroMQ Context."""
         self.context.destroy()
 
+    async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
+                                         expected_type: Any,
+                                         error_message: str) -> Any:
+        """Send an RPC request that is expecting data back."""
+
+        # Connect to socket.
+        socket = self.context.socket(zmq.constants.DEALER)
+        socket.connect(self.path)
+
+        # Ping RPCServer with a request.
+        await socket.send(pickle.dumps(request))
+
+        # Await the data from the Server.
+        data = pickle.loads(await socket.recv())
+        if not isinstance(data, expected_type):
+            # LoRAConfig can be None.
+            if expected_type == LoRAConfig and data is None:
+                pass
+            else:
+                socket.close()
+                raise ValueError(error_message)
+
+        socket.close()
+
+        return data
+
     async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
                                         error_message: str):
         """Send one-way RPC request to trigger an action."""
@@ -55,13 +96,14 @@ async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
         return response
 
     async def get_tokenizer(self, lora_request: LoRARequest):
-        # TODO: handle this via get data? - or avoid doing via RPC
-        return self.tokenizer
+        return await self.tokenizer.get_lora_tokenizer_async(lora_request)
 
     async def get_decoding_config(self):
-        # TODO: handle this via get data? - or avoid doing via RPC
         return self.decoding_config
 
+    async def get_model_config(self):
+        return self.model_config
+
     async def is_tracing_enabled(self):
         # TODO: what is this?
         return False
@@ -73,30 +115,48 @@ async def wait_for_server(self):
             request=RPCUtilityRequest.IS_SERVER_READY,
             error_message="Unable to start RPC Server.")
 
-    async def get_model_config(self) -> ModelConfig:
+    async def _get_model_config_rpc(self) -> ModelConfig:
         """Get the ModelConfig object from the RPC Server"""
 
-        # Connect to socket.
-        socket = self.context.socket(zmq.constants.DEALER)
-        socket.connect(self.path)
+        return await self._send_get_data_rpc_request(
+            RPCUtilityRequest.GET_MODEL_CONFIG,
+            expected_type=ModelConfig,
+            error_message="Could not get ModelConfig from RPC Server")
 
-        # Ping RPCServer with GET_MODEL_CONFIG request.
-        await socket.send(pickle.dumps(RPCUtilityRequest.GET_MODEL_CONFIG))
+    async def _get_decoding_config_rpc(self) -> DecodingConfig:
+        """Get DecodingConfig from the RPCServer"""
 
-        # Await the MODEL_CONFIG from the Server.
-        model_config = pickle.loads(await socket.recv())
+        return await self._send_get_data_rpc_request(
+            RPCUtilityRequest.GET_DECODING_CONFIG,
+            expected_type=DecodingConfig,
+            error_message="Could not get DecodingConfig from RPC Server")
 
-        if not isinstance(model_config, ModelConfig):
-            socket.close()
-            raise ValueError("Expected ModelConfig object from RPC, but "
-                             f"got {model_config}")
+    async def _get_parallel_config_rpc(self) -> ParallelConfig:
+        """Get ParallelConfig from the RPCServer"""
 
-        socket.close()
+        return await self._send_get_data_rpc_request(
+            RPCUtilityRequest.GET_PARALLEL_CONFIG,
+            expected_type=ParallelConfig,
+            error_message="Could not get ParallelConfig from RPC Server")
+
+    async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
+        """Get SchedulerConfig from the RPCServer"""
+
+        return await self._send_get_data_rpc_request(
+            RPCUtilityRequest.GET_SCHEDULER_CONFIG,
+            expected_type=SchedulerConfig,
+            error_message="Could not get SchedulerConfig from RPC Server")
+
+    async def _get_lora_config_rpc(self):
+        """Get LoRAConfig from the RPCServer"""
 
-        return model_config
+        return await self._send_get_data_rpc_request(
+            RPCUtilityRequest.GET_LORA_CONFIG,
+            expected_type=LoRAConfig,
+            error_message="Could not get LoRAConfig from RPC Server")
 
     async def abort(self, request_id: str):
-        """Send an RPCAbortRequest to the RPC Server"""
+        """Send an ABORT_REQUEST signal to the RPC Server"""
 
         await self._send_one_way_rpc_request(
             request=RPCAbortRequest(request_id),
diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py
index 73ae2aae06ea1..ca57295c69965 100644
--- a/vllm/entrypoints/openai/rpc/server.py
+++ b/vllm/entrypoints/openai/rpc/server.py
@@ -37,22 +37,47 @@ def cleanup(self):
         self.socket.close()
         self.context.destroy()
 
-    async def _send_success_message(self, identity):
-        """Send message to client indicating an action was successful."""
-        await self.socket.send_multipart([
-            identity,
-            pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL),
-        ])
-
     async def get_model_config(self, identity):
-        """Send the ModelConfig """
+        """Send the ModelConfig"""
         model_config = await self.engine.get_model_config()
 
         await self.socket.send_multipart(
             [identity,
              pickle.dumps(model_config, pickle.HIGHEST_PROTOCOL)])
 
+    async def get_decoding_config(self, identity):
+        """Send the DecodingConfig"""
+        decoding_config = await self.engine.get_decoding_config()
+
+        await self.socket.send_multipart(
+            [identity,
+             pickle.dumps(decoding_config, pickle.HIGHEST_PROTOCOL)])
+
+    async def get_lora_config(self, identity):
+        lora_config = await self.engine.get_lora_config()
+
+        await self.socket.send_multipart(
+            [identity,
+             pickle.dumps(lora_config, pickle.HIGHEST_PROTOCOL)])
+
+    async def get_scheduler_config(self, identity):
+        """Send the SchedulerConfig"""
+        scheduler_config = await self.engine.get_scheduler_config()
+
+        await self.socket.send_multipart(
+            [identity,
+             pickle.dumps(scheduler_config, pickle.HIGHEST_PROTOCOL)])
+
+    async def get_parallel_config(self, identity):
+        """Send the ParallelConfig"""
+        parallel_config = await self.engine.get_parallel_config()
+
+        await self.socket.send_multipart(
+            [identity,
+             pickle.dumps(parallel_config, pickle.HIGHEST_PROTOCOL)])
+
     async def do_log_stats(self, identity):
+        """Log stats and confirm success."""
         await self.engine.do_log_stats()
 
         await self.socket.send_multipart([
@@ -61,12 +86,14 @@ async def do_log_stats(self, identity):
         ])
 
     async def is_server_ready(self, identity):
+        """Notify the client that we are ready."""
         await self.socket.send_multipart([
             identity,
             pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL),
         ])
 
     async def abort(self, identity, request: RPCAbortRequest):
+        """Abort the request and notify the client of success."""
         # Abort the request in the llm engine.
         await self.engine.abort(request.request_id)
 
@@ -81,7 +108,10 @@ async def generate(self, identity, generate_request: RPCGenerateRequest):
         results_generator = self.engine.generate(
             generate_request.inputs,
             sampling_params=generate_request.sampling_params,
-            request_id=generate_request.request_id)
+            request_id=generate_request.request_id,
+            lora_request=generate_request.lora_request,
+            trace_headers=generate_request.trace_headers,
+            prompt_adapter_request=generate_request.prompt_adapter_request)
 
         async for request_output in results_generator:
             await self.socket.send_multipart([
@@ -120,6 +150,14 @@ def _make_handler_coro(self, identity,
         elif isinstance(request, RPCUtilityRequest):
             if request == RPCUtilityRequest.GET_MODEL_CONFIG:
                 return self.get_model_config(identity)
+            elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG:
+                return self.get_parallel_config(identity)
+            elif request == RPCUtilityRequest.GET_DECODING_CONFIG:
+                return self.get_decoding_config(identity)
+            elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG:
+                return self.get_scheduler_config(identity)
+            elif request == RPCUtilityRequest.GET_LORA_CONFIG:
+                return self.get_lora_config(identity)
             elif request == RPCUtilityRequest.DO_LOG_STATS:
                 return self.do_log_stats(identity)
             elif request == RPCUtilityRequest.IS_SERVER_READY:
diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py
index 9f54f5409b181..ae17ccf056b96 100644
--- a/vllm/transformers_utils/tokenizer_group/__init__.py
+++ b/vllm/transformers_utils/tokenizer_group/__init__.py
@@ -1,6 +1,7 @@
 from typing import Optional, Type
 
-from vllm.config import TokenizerPoolConfig
+from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
+                         TokenizerPoolConfig)
 from vllm.executor.ray_utils import ray
 from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
     BaseTokenizerGroup)
@@ -14,6 +15,22 @@
     RayTokenizerGroupPool = None  # type: ignore
 
 
+def init_tokenizer_from_configs(model_config: ModelConfig,
+                                scheduler_config: SchedulerConfig,
+                                parallel_config: ParallelConfig,
+                                enable_lora: bool):
+    init_kwargs = dict(tokenizer_id=model_config.tokenizer,
+                       enable_lora=enable_lora,
+                       max_num_seqs=scheduler_config.max_num_seqs,
+                       max_input_length=None,
+                       tokenizer_mode=model_config.tokenizer_mode,
+                       trust_remote_code=model_config.trust_remote_code,
+                       revision=model_config.tokenizer_revision)
+
+    return get_tokenizer_group(parallel_config.tokenizer_pool_config,
+                               **init_kwargs)
+
+
 def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
                         **init_kwargs) -> BaseTokenizerGroup:
     tokenizer_cls: Type[BaseTokenizerGroup]