diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 954cec734b956..8dd9b23fbdd5f 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -8,8 +8,8 @@ from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.core import EngineCore +from vllm.v1.executor.abstract import Executor if not current_platform.is_cuda(): pytest.skip(reason="V1 currently only supported on CUDA.", @@ -43,7 +43,7 @@ def test_engine_core(monkeypatch): """Setup the EngineCore.""" engine_args = EngineArgs(model=MODEL_NAME) vllm_config = engine_args.create_engine_config() - executor_class = AsyncLLM._get_executor_cls(vllm_config) + executor_class = Executor.get_class(vllm_config) engine_core = EngineCore(vllm_config=vllm_config, executor_class=executor_class) @@ -149,7 +149,7 @@ def test_engine_core_advanced_sampling(monkeypatch): """Setup the EngineCore.""" engine_args = EngineArgs(model=MODEL_NAME) vllm_config = engine_args.create_engine_config() - executor_class = AsyncLLM._get_executor_cls(vllm_config) + executor_class = Executor.get_class(vllm_config) engine_core = EngineCore(vllm_config=vllm_config, executor_class=executor_class) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 20d4e6f63b339..5a21806e57a11 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -11,8 +11,8 @@ from vllm.platforms import current_platform from vllm.usage.usage_lib import UsageContext from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.core_client import EngineCoreClient +from vllm.v1.executor.abstract import Executor if not current_platform.is_cuda(): pytest.skip(reason="V1 currently only supported on CUDA.", @@ -84,7 +84,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3) vllm_config = engine_args.create_engine_config( UsageContext.UNKNOWN_CONTEXT) - executor_class = AsyncLLM._get_executor_cls(vllm_config) + executor_class = Executor.get_class(vllm_config) client = EngineCoreClient.make_client( multiprocess_mode=multiprocessing_mode, asyncio_mode=False, @@ -152,7 +152,7 @@ async def test_engine_core_client_asyncio(monkeypatch): engine_args = EngineArgs(model=MODEL_NAME) vllm_config = engine_args.create_engine_config( usage_context=UsageContext.UNKNOWN_CONTEXT) - executor_class = AsyncLLM._get_executor_cls(vllm_config) + executor_class = Executor.get_class(vllm_config) client = EngineCoreClient.make_client( multiprocess_mode=True, asyncio_mode=True, diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 0696caf88385d..b963ba74f13f0 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -22,7 +22,6 @@ from vllm.v1.engine.detokenizer import Detokenizer from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor -from vllm.v1.executor.ray_utils import initialize_ray_cluster logger = init_logger(__name__) @@ -105,7 +104,7 @@ def from_engine_args( else: vllm_config = engine_config - executor_class = cls._get_executor_cls(vllm_config) + executor_class = Executor.get_class(vllm_config) # Create the AsyncLLM. return cls( @@ -127,24 +126,6 @@ def shutdown(self): if handler := getattr(self, "output_handler", None): handler.cancel() - @classmethod - def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]: - executor_class: Type[Executor] - distributed_executor_backend = ( - vllm_config.parallel_config.distributed_executor_backend) - if distributed_executor_backend == "ray": - initialize_ray_cluster(vllm_config.parallel_config) - from vllm.v1.executor.ray_executor import RayExecutor - executor_class = RayExecutor - elif distributed_executor_backend == "mp": - from vllm.v1.executor.multiproc_executor import MultiprocExecutor - executor_class = MultiprocExecutor - else: - assert (distributed_executor_backend is None) - from vllm.v1.executor.uniproc_executor import UniprocExecutor - executor_class = UniprocExecutor - return executor_class - async def add_request( self, request_id: str, diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 0bd9b52c9be82..8ced3a34d2da3 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -89,7 +89,7 @@ def from_engine_args( # Create the engine configs. vllm_config = engine_args.create_engine_config(usage_context) - executor_class = cls._get_executor_cls(vllm_config) + executor_class = Executor.get_class(vllm_config) if VLLM_ENABLE_V1_MULTIPROCESSING: logger.debug("Enabling multiprocessing for LLMEngine.") @@ -103,24 +103,6 @@ def from_engine_args( stat_loggers=stat_loggers, multiprocess_mode=enable_multiprocessing) - @classmethod - def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]: - executor_class: Type[Executor] - distributed_executor_backend = ( - vllm_config.parallel_config.distributed_executor_backend) - if distributed_executor_backend == "ray": - from vllm.v1.executor.ray_executor import RayExecutor - executor_class = RayExecutor - elif distributed_executor_backend == "mp": - from vllm.v1.executor.multiproc_executor import MultiprocExecutor - executor_class = MultiprocExecutor - else: - assert (distributed_executor_backend is None) - from vllm.v1.executor.uniproc_executor import UniprocExecutor - executor_class = UniprocExecutor - - return executor_class - def get_num_unfinished_requests(self) -> int: return self.detokenizer.get_num_unfinished_requests() diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 564d0447f15a6..5d74d4b01f500 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Tuple +from typing import Tuple, Type from vllm.config import VllmConfig from vllm.v1.outputs import ModelRunnerOutput @@ -8,6 +8,23 @@ class Executor(ABC): """Abstract class for executors.""" + @staticmethod + def get_class(vllm_config: VllmConfig) -> Type["Executor"]: + executor_class: Type[Executor] + distributed_executor_backend = ( + vllm_config.parallel_config.distributed_executor_backend) + if distributed_executor_backend == "ray": + from vllm.v1.executor.ray_executor import RayExecutor + executor_class = RayExecutor + elif distributed_executor_backend == "mp": + from vllm.v1.executor.multiproc_executor import MultiprocExecutor + executor_class = MultiprocExecutor + else: + assert (distributed_executor_backend is None) + from vllm.v1.executor.uniproc_executor import UniprocExecutor + executor_class = UniprocExecutor + return executor_class + @abstractmethod def __init__(self, vllm_config: VllmConfig) -> None: raise NotImplementedError