From 30a4f4dff1d1ba132e4142e6fc8255bc233c0048 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 31 Jul 2024 20:19:49 +0000 Subject: [PATCH 1/9] pass configs --- vllm/engine/async_llm_engine.py | 27 +++- vllm/engine/llm_engine.py | 34 +++-- vllm/entrypoints/openai/api_server.py | 2 +- vllm/entrypoints/openai/rpc/__init__.py | 7 +- vllm/entrypoints/openai/rpc/client.py | 127 +++++++++++++----- vllm/entrypoints/openai/rpc/server.py | 50 ++++++- .../tokenizer_group/__init__.py | 20 ++- 7 files changed, 205 insertions(+), 62 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 16b7bc64a2849..849e3095d1b70 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -7,7 +7,8 @@ from transformers import PreTrainedTokenizer import vllm.envs as envs -from vllm.config import DecodingConfig, EngineConfig, ModelConfig +from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig) from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout @@ -923,6 +924,14 @@ async def get_model_config(self) -> ModelConfig: return await self.engine.get_model_config.remote() # type: ignore else: return self.engine.get_model_config() + + async def get_parallel_config(self) -> ParallelConfig: + """Get the parallel configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_parallel_config.remote( # type: ignore + ) + else: + return self.engine.get_parallel_config() async def get_decoding_config(self) -> DecodingConfig: """Get the decoding configuration of the vLLM engine.""" @@ -932,6 +941,22 @@ async def get_decoding_config(self) -> DecodingConfig: else: return self.engine.get_decoding_config() + async def get_scheduler_config(self) -> SchedulerConfig: + """Get the scheduling configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_scheduler_config.remote( # type: ignore + ) + else: + return self.engine.get_scheduler_config() + + async def get_lora_config(self) -> LoRAConfig: + """Get the lora configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_lora_config.remote( # type: ignore + ) + else: + return self.engine.get_lora_config() + async def do_log_stats( self, scheduler_outputs: Optional[SchedulerOutputs] = None, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 48d5305892219..aebd03d6e1d04 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -41,7 +41,7 @@ from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, - get_tokenizer_group) + _init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import Counter @@ -481,19 +481,13 @@ def get_tokenizer_for_seq(self, return self.get_tokenizer_group().get_lora_tokenizer( sequence.lora_request) - def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup: - init_kwargs = dict( - tokenizer_id=self.model_config.tokenizer, - enable_lora=bool(self.lora_config), - max_num_seqs=self.scheduler_config.max_num_seqs, - max_input_length=None, - tokenizer_mode=self.model_config.tokenizer_mode, - trust_remote_code=self.model_config.trust_remote_code, - 
revision=self.model_config.tokenizer_revision) - init_kwargs.update(tokenizer_init_kwargs) - - return get_tokenizer_group(self.parallel_config.tokenizer_pool_config, - **init_kwargs) + def _init_tokenizer(self) -> BaseTokenizerGroup: + return _init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=self.scheduler_config, + parallel_config=self.parallel_config, + enable_lora=bool(self.lora_config) + ) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) @@ -755,9 +749,21 @@ def get_model_config(self) -> ModelConfig: """Gets the model configuration.""" return self.model_config + def get_parallel_config(self) -> ParallelConfig: + """Gets the parallel configuration.""" + return self.get_parallel_config + def get_decoding_config(self) -> DecodingConfig: """Gets the decoding configuration.""" return self.decoding_config + + def get_scheduler_config(self) -> SchedulerConfig: + """Gets the scheduler configuration.""" + return self.scheduler_config + + def get_lora_config(self) -> LoRAConfig: + """Gets the LoRA configuration.""" + return self.lora_config def get_num_unfinished_requests(self) -> int: """Gets the number of unfinished requests.""" diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 772738351cda5..631e47fb899ed 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -119,7 +119,7 @@ async def build_backend(args) -> AsyncIterator[VLLMBackend]: backend = RPCClient(tokenizer=AutoTokenizer.from_pretrained( args.model), port=port) - await backend.wait_for_server() + await backend.connect_to_server() try: yield backend diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py index 7187bcdbe77bc..4712ac00cd3c4 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -29,8 +29,11 @@ class RPCAbortRequest: class RPCUtilityRequest(Enum): IS_SERVER_READY = 1 GET_MODEL_CONFIG = 2 - DO_LOG_STATS = 3 - CHECK_HEALTH = 4 + GET_DECODING_CONFIG = 3 + GET_PARALLEL_CONFIG = 4 + GET_LORA_CONFIG = 5 + DO_LOG_STATS = 6 + CHECK_HEALTH = 7 RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 9bcdf6c48bbcc..6fc7a85a59abe 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -1,37 +1,72 @@ import pickle -from typing import AsyncIterator, Mapping, Optional +from typing import Any, AsyncIterator, Mapping, Optional import zmq import zmq.asyncio -from vllm.config import DecodingConfig, ModelConfig -from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, - VLLM_RPC_HEALTHY_STR, - VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCGenerateRequest, RPCUtilityRequest) +from vllm.config import (DecodingConfig, ModelConfig, ParallelConfig, + LoRAConfig, SchedulerConfig) +from vllm.entrypoints.openai.rpc import ( + RPC_REQUEST_TYPE, VLLM_RPC_HEALTHY_STR, VLLM_RPC_SUCCESS_STR, + RPCAbortRequest, RPCGenerateRequest, RPCUtilityRequest) from vllm.inputs import PromptInputs from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams - +from vllm.transformers_utils.tokenizer_group import _init_tokenizer_from_configs class RPCClient: - - # TODO: check if opening all these sockets is an antipattern? 
- def __init__(self, tokenizer, port: int): - # ZMQ context. + def __init__(self, port: int): self.context = zmq.asyncio.Context() - - # TODO: do the tokenizer properly. - self.tokenizer = tokenizer - self.decoding_config = DecodingConfig() self.path = f"tcp://localhost:{port}" + + async def setup(self): + """Setup the client before it starts sending server requests.""" + + # Wait until server is ready. + await self.wait_for_server() + + # Get the configs. + self.model_config = await self.get_model_config() + self.decoding_config = await self.get_decoding_config() + + # Create the tokenizer group. + self.tokenizer_group = _init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=(await self.get_scheduler_config), + parallel_config=(await self.get_parallel_config()), + enable_lora=bool(await self.get_lora_config), + ) + def close(self): """Destroy the ZeroMQ Context.""" self.context.destroy() + async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, + expected_type: Any, + error_message: str) -> Any: + """Send an RPC request that is expecting data back.""" + + # Connect to socket. + socket = self.context.socket(zmq.constants.DEALER) + socket.connect(self.path) + + # Ping RPCServer with GET_MODEL_CONFIG request. + await socket.send(pickle.dumps(request)) + + # Await the MODEL_CONFIG from the Server. + data = pickle.loads(await socket.recv()) + + if not isinstance(data, expected_type): + socket.close() + raise ValueError(error_message) + + socket.close() + + return data + async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, error_message: str): """Send one-way RPC request to trigger an action.""" @@ -58,10 +93,6 @@ async def get_tokenizer(self, lora_request: LoRARequest): # TODO: handle this via get data? - or avoid doing via RPC return self.tokenizer - async def get_decoding_config(self): - # TODO: handle this via get data? - or avoid doing via RPC - return self.decoding_config - async def is_tracing_enabled(self): # TODO: what is this? return False @@ -76,27 +107,51 @@ async def wait_for_server(self): async def get_model_config(self) -> ModelConfig: """Get the ModelConfig object from the RPC Server""" - # Connect to socket. - socket = self.context.socket(zmq.constants.DEALER) - socket.connect(self.path) - - # Ping RPCServer with GET_MODEL_CONFIG request. - await socket.send(pickle.dumps(RPCUtilityRequest.GET_MODEL_CONFIG)) - - # Await the MODEL_CONFIG from the Server. 
- model_config = pickle.loads(await socket.recv()) - - if not isinstance(model_config, ModelConfig): - socket.close() - raise ValueError("Expected ModelConfig object from RPC, but " - f"got {model_config}") + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_MODEL_CONFIG, + expected_type=ModelConfig, + error_message="Could not get ModelConfig from RPC Server" + ) - socket.close() + async def get_decoding_config(self): + """Get DecodingConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_DECODING_CONFIG, + expected_type=ModelConfig, + error_message="Could not get DecodingConfig from RPC Server" + ) + + async def get_parallel_config(self): + """Get ParallelConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_PARALLEL_CONFIG, + expected_type=ModelConfig, + error_message="Could not get ModelConfig from RPC Server" + ) + + async def get_scheduler_config(self): + """Get SchedulerConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_SCHEDULER_CONFIG, + expected_type=SchedulerConfig, + error_message="Could not get SchedulerConfig from RPC Server" + ) + + async def get_lora_config(self): + """Get LoRAConfig from the RPCServer""" + + return await self._send_get_data_rpc_request( + RPCUtilityRequest.GET_SCHEDULER_CONFIG, + expected_type=LoRAConfig, + error_message="Could not get LoRAConfig from RPC Server" + ) - return model_config async def abort(self, request_id: str): - """Send an RPCAbortRequest to the RPC Server""" + """Send an ABORT_REQUEST signal to the RPC Server""" await self._send_one_way_rpc_request( request=RPCAbortRequest(request_id), diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 73ae2aae06ea1..8942d956c52ce 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -37,22 +37,50 @@ def cleanup(self): self.socket.close() self.context.destroy() - async def _send_success_message(self, identity): - """Send message to client indicating an action was successful.""" - await self.socket.send_multipart([ - identity, - pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL), - ]) + async def get_model_config(self, identity): - """Send the ModelConfig """ + """Send the ModelConfig""" model_config = await self.engine.get_model_config() await self.socket.send_multipart( [identity, pickle.dumps(model_config, pickle.HIGHEST_PROTOCOL)]) + async def get_decoding_config(self, identity): + """Send the DecodingConfig""" + decoding_config = await self.engine.get_decoding_config() + + await self.socket.send_multipart( + [identity, + pickle.dumps(decoding_config, pickle.HIGHEST_PROTOCOL)]) + + async def get_lora_config(self, identity): + """Send the LoRAConfig""" + lora_config = await self.engine.get_lora_config() + + await self.socket.send_multipart( + [identity, + pickle.dumps(lora_config, pickle.HIGHEST_PROTOCOL)]) + + async def get_scheduler_config(self, identity): + """Send the SchedulerConfig""" + parallel_config = await self.engine.get_scheduler_config() + + await self.socket.send_multipart( + [identity, + pickle.dumps(parallel_config, pickle.HIGHEST_PROTOCOL)]) + + async def get_parallel_config(self, identity): + """Send the ParallelConfig""" + parallel_config = await self.engine.get_parallel_config() + + await self.socket.send_multipart( + [identity, + pickle.dumps(parallel_config, pickle.HIGHEST_PROTOCOL)]) + async def do_log_stats(self, 
identity): + """Log stats and confirm success.""" await self.engine.do_log_stats() await self.socket.send_multipart([ @@ -61,12 +89,14 @@ async def do_log_stats(self, identity): ]) async def is_server_ready(self, identity): + """Notify the client that we are ready.""" await self.socket.send_multipart([ identity, pickle.dumps(VLLM_RPC_SUCCESS_STR, pickle.HIGHEST_PROTOCOL), ]) async def abort(self, identity, request: RPCAbortRequest): + """Abort request and notify the client of success.""" # Abort the request in the llm engine. await self.engine.abort(request.request_id) @@ -120,6 +150,12 @@ def _make_handler_coro(self, identity, elif isinstance(request, RPCUtilityRequest): if request == RPCUtilityRequest.GET_MODEL_CONFIG: return self.get_model_config(identity) + elif request == RPCUtilityRequest.GET_PARALLEL_CONFIG: + return self.get_parallel_config(identity) + elif request == RPCUtilityRequest.GET_DECODING_CONFIG: + return self.get_decoding_config(identity) + elif request == RPCUtilityRequest.GET_LORA_CONFIG: + return self.get_lora_config(identity) elif request == RPCUtilityRequest.DO_LOG_STATS: return self.do_log_stats(identity) elif request == RPCUtilityRequest.IS_SERVER_READY: diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index 9f54f5409b181..f9af8c3507520 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -1,6 +1,7 @@ from typing import Optional, Type -from vllm.config import TokenizerPoolConfig +from vllm.config import (TokenizerPoolConfig, ModelConfig, + ParallelConfig, SchedulerConfig) from vllm.executor.ray_utils import ray from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( BaseTokenizerGroup) @@ -14,6 +15,23 @@ RayTokenizerGroupPool = None # type: ignore +def _init_tokenizer_from_configs(model_config: ModelConfig, + scheduler_config: SchedulerConfig, + parallel_config: ParallelConfig, + enable_lora: bool): + init_kwargs = dict( + tokenizer_id=model_config.tokenizer, + enable_lora=enable_lora, + max_num_seqs=scheduler_config.max_num_seqs, + max_input_length=None, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, + revision=model_config.tokenizer_revision) + + return get_tokenizer_group(parallel_config.tokenizer_pool_config, + **init_kwargs) + + def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], **init_kwargs) -> BaseTokenizerGroup: tokenizer_cls: Type[BaseTokenizerGroup] From bd27519d29f97c782534b54e3621aaab2f18ebfc Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 31 Jul 2024 20:24:47 +0000 Subject: [PATCH 2/9] almost there --- vllm/entrypoints/openai/rpc/__init__.py | 7 ++++--- vllm/entrypoints/openai/rpc/client.py | 6 ++---- vllm/entrypoints/openai/rpc/server.py | 2 ++ 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/vllm/entrypoints/openai/rpc/__init__.py b/vllm/entrypoints/openai/rpc/__init__.py index 4712ac00cd3c4..0c055b76fe2aa 100644 --- a/vllm/entrypoints/openai/rpc/__init__.py +++ b/vllm/entrypoints/openai/rpc/__init__.py @@ -31,9 +31,10 @@ class RPCUtilityRequest(Enum): GET_MODEL_CONFIG = 2 GET_DECODING_CONFIG = 3 GET_PARALLEL_CONFIG = 4 - GET_LORA_CONFIG = 5 - DO_LOG_STATS = 6 - CHECK_HEALTH = 7 + GET_SCHEDULER_CONFIG = 5 + GET_LORA_CONFIG = 6 + DO_LOG_STATS = 7 + CHECK_HEALTH = 8 RPC_REQUEST_TYPE = Union[RPCGenerateRequest, RPCAbortRequest, diff --git a/vllm/entrypoints/openai/rpc/client.py 
b/vllm/entrypoints/openai/rpc/client.py index 6fc7a85a59abe..dc0e41f83c497 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -90,8 +90,7 @@ async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, return response async def get_tokenizer(self, lora_request: LoRARequest): - # TODO: handle this via get data? - or avoid doing via RPC - return self.tokenizer + await self.tokenizer.get_lora_tokenizer_async(lora_request) async def is_tracing_enabled(self): # TODO: what is this? @@ -144,12 +143,11 @@ async def get_lora_config(self): """Get LoRAConfig from the RPCServer""" return await self._send_get_data_rpc_request( - RPCUtilityRequest.GET_SCHEDULER_CONFIG, + RPCUtilityRequest.GET_LORA_CONFIG, expected_type=LoRAConfig, error_message="Could not get LoRAConfig from RPC Server" ) - async def abort(self, request_id: str): """Send an ABORT_REQUEST signal to the RPC Server""" diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 8942d956c52ce..6d72cec50a977 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -154,6 +154,8 @@ def _make_handler_coro(self, identity, return self.get_parallel_config(identity) elif request == RPCUtilityRequest.GET_DECODING_CONFIG: return self.get_decoding_config(identity) + elif request == RPCUtilityRequest.GET_SCHEDULER_CONFIG: + return self.get_scheduler_config(identity) elif request == RPCUtilityRequest.GET_LORA_CONFIG: return self.get_lora_config(identity) elif request == RPCUtilityRequest.DO_LOG_STATS: From 11d4de5116aa487184542329c63470ddb03cf13e Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 31 Jul 2024 20:48:18 +0000 Subject: [PATCH 3/9] formatted --- vllm/engine/async_llm_engine.py | 8 +- vllm/engine/llm_engine.py | 11 ++- vllm/entrypoints/openai/api_server.py | 7 +- vllm/entrypoints/openai/rpc/client.py | 86 ++++++++++--------- vllm/entrypoints/openai/rpc/server.py | 9 +- .../tokenizer_group/__init__.py | 21 +++-- 6 files changed, 71 insertions(+), 71 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 849e3095d1b70..0584d8eb6f323 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -7,8 +7,8 @@ from transformers import PreTrainedTokenizer import vllm.envs as envs -from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig) +from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout @@ -924,7 +924,7 @@ async def get_model_config(self) -> ModelConfig: return await self.engine.get_model_config.remote() # type: ignore else: return self.engine.get_model_config() - + async def get_parallel_config(self) -> ParallelConfig: """Get the parallel configuration of the vLLM engine.""" if self.engine_use_ray: @@ -948,7 +948,7 @@ async def get_scheduler_config(self) -> SchedulerConfig: ) else: return self.engine.get_scheduler_config() - + async def get_lora_config(self) -> LoRAConfig: """Get the lora configuration of the vLLM engine.""" if self.engine_use_ray: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index aebd03d6e1d04..6215599f016e8 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -40,8 +40,8 @@ init_tracer) from 
vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.detokenizer import Detokenizer -from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, - _init_tokenizer_from_configs) +from vllm.transformers_utils.tokenizer_group import ( + BaseTokenizerGroup, _init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import Counter @@ -486,8 +486,7 @@ def _init_tokenizer(self) -> BaseTokenizerGroup: model_config=self.model_config, scheduler_config=self.scheduler_config, parallel_config=self.parallel_config, - enable_lora=bool(self.lora_config) - ) + enable_lora=bool(self.lora_config)) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) @@ -751,12 +750,12 @@ def get_model_config(self) -> ModelConfig: def get_parallel_config(self) -> ParallelConfig: """Gets the parallel configuration.""" - return self.get_parallel_config + return self.parallel_config def get_decoding_config(self) -> DecodingConfig: """Gets the decoding configuration.""" return self.decoding_config - + def get_scheduler_config(self) -> SchedulerConfig: """Gets the scheduler configuration.""" return self.scheduler_config diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 631e47fb899ed..07885260dfc76 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -16,7 +16,6 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse from prometheus_client import make_asgi_app from starlette.routing import Mount -from transformers import AutoTokenizer import vllm.envs as envs from vllm.config import ModelConfig @@ -116,10 +115,8 @@ async def build_backend(args) -> AsyncIterator[VLLMBackend]: ## Then build the client for the backend process # TODO: figure out a way around passing the tokenizer - backend = RPCClient(tokenizer=AutoTokenizer.from_pretrained( - args.model), - port=port) - await backend.connect_to_server() + backend = RPCClient(port) + await backend.setup() try: yield backend diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index dc0e41f83c497..a91a356a480fa 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -4,42 +4,46 @@ import zmq import zmq.asyncio -from vllm.config import (DecodingConfig, ModelConfig, ParallelConfig, - LoRAConfig, SchedulerConfig) -from vllm.entrypoints.openai.rpc import ( - RPC_REQUEST_TYPE, VLLM_RPC_HEALTHY_STR, VLLM_RPC_SUCCESS_STR, - RPCAbortRequest, RPCGenerateRequest, RPCUtilityRequest) +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from vllm.entrypoints.openai.rpc import (RPC_REQUEST_TYPE, + VLLM_RPC_HEALTHY_STR, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCGenerateRequest, RPCUtilityRequest) from vllm.inputs import PromptInputs from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import _init_tokenizer_from_configs +from vllm.transformers_utils.tokenizer_group import ( + _init_tokenizer_from_configs) + class RPCClient: + def __init__(self, port: int): self.context = zmq.asyncio.Context() self.path = f"tcp://localhost:{port}" - + async def setup(self): """Setup the client before it starts sending server 
requests."""
-    
+
         # Wait until server is ready.
         await self.wait_for_server()
 
         # Get the configs.
-        self.model_config = await self.get_model_config()
-        self.decoding_config = await self.get_decoding_config()
+        self.model_config = await self._get_model_config_rpc()
+        self.decoding_config = await self._get_decoding_config_rpc()
 
         # Create the tokenizer group.
-        self.tokenizer_group = _init_tokenizer_from_configs(
+        # Note: this is a hack until we fully
+        self.tokenizer = _init_tokenizer_from_configs(
             model_config=self.model_config,
-            scheduler_config=(await self.get_scheduler_config),
-            parallel_config=(await self.get_parallel_config()),
-            enable_lora=bool(await self.get_lora_config),
+            scheduler_config=(await self._get_scheduler_config_rpc()),
+            parallel_config=(await self._get_parallel_config_rpc()),
+            enable_lora=bool(await self._get_lora_config_rpc()),
         )
-    
 
     def close(self):
         """Destroy the ZeroMQ Context."""
         self.context.destroy()
 
@@ -53,15 +57,19 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest,
         socket = self.context.socket(zmq.constants.DEALER)
         socket.connect(self.path)
 
-        # Ping RPCServer with GET_MODEL_CONFIG request.
+        # Ping RPCServer with a request.
         await socket.send(pickle.dumps(request))
 
-        # Await the MODEL_CONFIG from the Server.
+        # Await the data from the Server.
         data = pickle.loads(await socket.recv())
 
-        if not isinstance(data, expected_type):
-            socket.close()
-            raise ValueError(error_message)
+        if not isinstance(data, expected_type):
+            # LoRAConfig can be None.
+            if expected_type == LoRAConfig and data is None:
+                pass
+            else:
+                socket.close()
+                raise ValueError(error_message)
 
         socket.close()
 
@@ -90,7 +97,13 @@ async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE,
         return response
 
     async def get_tokenizer(self, lora_request: LoRARequest):
-        await self.tokenizer.get_lora_tokenizer_async(lora_request)
+        return await self.tokenizer.get_lora_tokenizer_async(lora_request)
+
+    async def get_decoding_config(self):
+        return self.decoding_config
+
+    async def get_model_config(self):
+        return self.model_config
 
     async def is_tracing_enabled(self):
         # TODO: what is this?
@@ -103,50 +116,45 @@ async def wait_for_server(self): request=RPCUtilityRequest.IS_SERVER_READY, error_message="Unable to start RPC Server.") - async def get_model_config(self) -> ModelConfig: + async def _get_model_config_rpc(self) -> ModelConfig: """Get the ModelConfig object from the RPC Server""" return await self._send_get_data_rpc_request( RPCUtilityRequest.GET_MODEL_CONFIG, expected_type=ModelConfig, - error_message="Could not get ModelConfig from RPC Server" - ) + error_message="Could not get ModelConfig from RPC Server") - async def get_decoding_config(self): + async def _get_decoding_config_rpc(self) -> DecodingConfig: """Get DecodingConfig from the RPCServer""" return await self._send_get_data_rpc_request( RPCUtilityRequest.GET_DECODING_CONFIG, - expected_type=ModelConfig, - error_message="Could not get DecodingConfig from RPC Server" - ) + expected_type=DecodingConfig, + error_message="Could not get DecodingConfig from RPC Server") - async def get_parallel_config(self): + async def _get_parallel_config_rpc(self) -> ParallelConfig: """Get ParallelConfig from the RPCServer""" return await self._send_get_data_rpc_request( RPCUtilityRequest.GET_PARALLEL_CONFIG, - expected_type=ModelConfig, - error_message="Could not get ModelConfig from RPC Server" - ) - - async def get_scheduler_config(self): + expected_type=ParallelConfig, + error_message="Could not get ModelConfig from RPC Server") + + async def _get_scheduler_config_rpc(self) -> SchedulerConfig: """Get SchedulerConfig from the RPCServer""" return await self._send_get_data_rpc_request( RPCUtilityRequest.GET_SCHEDULER_CONFIG, expected_type=SchedulerConfig, - error_message="Could not get SchedulerConfig from RPC Server" - ) + error_message="Could not get SchedulerConfig from RPC Server") - async def get_lora_config(self): + async def _get_lora_config_rpc(self): """Get LoRAConfig from the RPCServer""" return await self._send_get_data_rpc_request( RPCUtilityRequest.GET_LORA_CONFIG, expected_type=LoRAConfig, - error_message="Could not get LoRAConfig from RPC Server" - ) + error_message="Could not get LoRAConfig from RPC Server") async def abort(self, request_id: str): """Send an ABORT_REQUEST signal to the RPC Server""" diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 6d72cec50a977..d9dd21184c783 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -37,8 +37,6 @@ def cleanup(self): self.socket.close() self.context.destroy() - - async def get_model_config(self, identity): """Send the ModelConfig""" model_config = await self.engine.get_model_config() @@ -54,9 +52,8 @@ async def get_decoding_config(self, identity): await self.socket.send_multipart( [identity, pickle.dumps(decoding_config, pickle.HIGHEST_PROTOCOL)]) - + async def get_lora_config(self, identity): - """Send the LoRAConfig""" lora_config = await self.engine.get_lora_config() await self.socket.send_multipart( @@ -69,7 +66,7 @@ async def get_scheduler_config(self, identity): await self.socket.send_multipart( [identity, - pickle.dumps(parallel_config, pickle.HIGHEST_PROTOCOL)]) + pickle.dumps(parallel_config, pickle.HIGHEST_PROTOCOL)]) async def get_parallel_config(self, identity): """Send the ParallelConfig""" @@ -77,7 +74,7 @@ async def get_parallel_config(self, identity): await self.socket.send_multipart( [identity, - pickle.dumps(parallel_config, pickle.HIGHEST_PROTOCOL)]) + pickle.dumps(parallel_config, pickle.HIGHEST_PROTOCOL)]) async def do_log_stats(self, identity): """Log 
stats and confirm success.""" diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index f9af8c3507520..9ee2a4c9f40c7 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -1,7 +1,7 @@ from typing import Optional, Type -from vllm.config import (TokenizerPoolConfig, ModelConfig, - ParallelConfig, SchedulerConfig) +from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, + TokenizerPoolConfig) from vllm.executor.ray_utils import ray from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( BaseTokenizerGroup) @@ -19,15 +19,14 @@ def _init_tokenizer_from_configs(model_config: ModelConfig, scheduler_config: SchedulerConfig, parallel_config: ParallelConfig, enable_lora: bool): - init_kwargs = dict( - tokenizer_id=model_config.tokenizer, - enable_lora=enable_lora, - max_num_seqs=scheduler_config.max_num_seqs, - max_input_length=None, - tokenizer_mode=model_config.tokenizer_mode, - trust_remote_code=model_config.trust_remote_code, - revision=model_config.tokenizer_revision) - + init_kwargs = dict(tokenizer_id=model_config.tokenizer, + enable_lora=enable_lora, + max_num_seqs=scheduler_config.max_num_seqs, + max_input_length=None, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, + revision=model_config.tokenizer_revision) + return get_tokenizer_group(parallel_config.tokenizer_pool_config, **init_kwargs) From 9dadae6745e3cae3e0c0e35a3a983cf580d26bac Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 31 Jul 2024 20:50:27 +0000 Subject: [PATCH 4/9] comment --- vllm/entrypoints/openai/api_server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 07885260dfc76..104f70f1386ad 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -114,7 +114,6 @@ async def build_backend(args) -> AsyncIterator[VLLMBackend]: rpc_server_process.start() ## Then build the client for the backend process - # TODO: figure out a way around passing the tokenizer backend = RPCClient(port) await backend.setup() From 8edd734025070bcfc4a68deeefab34fe8f5b26d6 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 31 Jul 2024 20:51:44 +0000 Subject: [PATCH 5/9] better comment --- vllm/entrypoints/openai/rpc/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index a91a356a480fa..11db54e2779cd 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -36,7 +36,7 @@ async def setup(self): self.decoding_config = await self._get_decoding_config_rpc() # Create the tokenizer group. - # Note: this is a hack until we fully + # TODO: refactor OAI server to avoid needing this info. 
self.tokenizer = _init_tokenizer_from_configs( model_config=self.model_config, scheduler_config=(await self._get_scheduler_config_rpc()), From a4fc498a9d9077439ad0c04c9e25d568c7944008 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 31 Jul 2024 20:52:10 +0000 Subject: [PATCH 6/9] another typo --- vllm/entrypoints/openai/rpc/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 11db54e2779cd..263dd23f2d7bc 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -63,7 +63,7 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, # Await the data from the Server. data = pickle.loads(await socket.recv()) if not isinstance(data, expected_type): - # LoRAConfig can be None. + # LoRAConfig can be None. if expected_type == LoRAConfig and data is None: pass else: From b73c5da309303d07934200ec115bcbdc2445e09b Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 31 Jul 2024 21:24:38 +0000 Subject: [PATCH 7/9] fix provate method --- vllm/engine/llm_engine.py | 4 ++-- vllm/entrypoints/openai/rpc/client.py | 5 ++--- vllm/transformers_utils/tokenizer_group/__init__.py | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 6215599f016e8..627f028b99d7e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -41,7 +41,7 @@ from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import ( - BaseTokenizerGroup, _init_tokenizer_from_configs) + BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import Counter @@ -482,7 +482,7 @@ def get_tokenizer_for_seq(self, sequence.lora_request) def _init_tokenizer(self) -> BaseTokenizerGroup: - return _init_tokenizer_from_configs( + return init_tokenizer_from_configs( model_config=self.model_config, scheduler_config=self.scheduler_config, parallel_config=self.parallel_config, diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 263dd23f2d7bc..f69e7c24b449e 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -15,8 +15,7 @@ from vllm.outputs import RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import ( - _init_tokenizer_from_configs) +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs class RPCClient: @@ -37,7 +36,7 @@ async def setup(self): # Create the tokenizer group. # TODO: refactor OAI server to avoid needing this info. 
- self.tokenizer = _init_tokenizer_from_configs( + self.tokenizer = init_tokenizer_from_configs( model_config=self.model_config, scheduler_config=(await self._get_scheduler_config_rpc()), parallel_config=(await self._get_parallel_config_rpc()), diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index 9ee2a4c9f40c7..2a07edde3f694 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -15,7 +15,7 @@ RayTokenizerGroupPool = None # type: ignore -def _init_tokenizer_from_configs(model_config: ModelConfig, +def _configs(model_config: ModelConfig, scheduler_config: SchedulerConfig, parallel_config: ParallelConfig, enable_lora: bool): From 7266719b63c6cd4b46e710a7d68467ca6e94fea1 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 31 Jul 2024 21:34:14 +0000 Subject: [PATCH 8/9] fix provate method --- vllm/transformers_utils/tokenizer_group/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index 2a07edde3f694..ae17ccf056b96 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -15,7 +15,7 @@ RayTokenizerGroupPool = None # type: ignore -def _configs(model_config: ModelConfig, +def init_tokenizer_from_configs(model_config: ModelConfig, scheduler_config: SchedulerConfig, parallel_config: ParallelConfig, enable_lora: bool): From 38f6568c4d66a5b22b1f95657dc716418259f607 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 31 Jul 2024 21:52:09 +0000 Subject: [PATCH 9/9] fixed plumbing --- tests/entrypoints/openai/test_completion.py | 1 + vllm/entrypoints/openai/rpc/server.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index fe00640c0021e..521a450f13568 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -119,6 +119,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, choice = completion.choices[0] assert len(choice.text) >= 5 assert choice.finish_reason == "length" + print(completion.usage) assert completion.usage == openai.types.CompletionUsage( completion_tokens=5, prompt_tokens=6 + num_virtual_tokens, diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index d9dd21184c783..ca57295c69965 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -108,7 +108,10 @@ async def generate(self, identity, generate_request: RPCGenerateRequest): results_generator = self.engine.generate( generate_request.inputs, sampling_params=generate_request.sampling_params, - request_id=generate_request.request_id) + request_id=generate_request.request_id, + lora_request=generate_request.lora_request, + trace_headers=generate_request.trace_headers, + prompt_adapter_request=generate_request.prompt_adapter_request) async for request_output in results_generator: await self.socket.send_multipart([