diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 7442de245bd80..bff557d7fc92f 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -107,7 +107,7 @@ steps:
   source_file_dependencies:
   - vllm/
   commands:
-  - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
+  - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
   - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
@@ -466,7 +466,9 @@ steps:
   - vllm/worker/worker_base.py
   - vllm/worker/worker.py
   - vllm/worker/model_runner.py
+  - entrypoints/llm/test_collective_rpc.py
   commands:
+  - pytest -v -s entrypoints/llm/test_collective_rpc.py
   - torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
   - pytest -v -s ./compile/test_basic_correctness.py
   - pytest -v -s ./compile/test_wrapper.py
diff --git a/tests/engine/test_custom_executor.py b/tests/engine/test_custom_executor.py
index 2a057ca488a50..fdfcd4f4c9d50 100644
--- a/tests/engine/test_custom_executor.py
+++ b/tests/engine/test_custom_executor.py
@@ -1,6 +1,6 @@
 import asyncio
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import pytest
 
@@ -18,7 +18,7 @@ class Mock:
 class CustomUniExecutor(UniProcExecutor):
 
     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
diff --git a/tests/entrypoints/llm/test_collective_rpc.py b/tests/entrypoints/llm/test_collective_rpc.py
new file mode 100644
index 0000000000000..22473ce275295
--- /dev/null
+++ b/tests/entrypoints/llm/test_collective_rpc.py
@@ -0,0 +1,36 @@
+import pytest
+
+from vllm import LLM
+
+from ...utils import fork_new_process_for_each_test
+
+
+@pytest.mark.parametrize("tp_size", [1, 2])
+@pytest.mark.parametrize("backend", ["mp", "ray"])
+@fork_new_process_for_each_test
+def test_collective_rpc(tp_size, backend):
+    if tp_size == 1 and backend == "ray":
+        pytest.skip("Skip duplicate test case")
+    if tp_size == 1:
+        backend = None
+
+    # Intentionally define the method and class inside the test function,
+    # to check that they can be serialized and sent to the workers.
+    def echo_rank(self):
+        return self.rank
+
+    from vllm.worker.worker import Worker
+
+    class MyWorker(Worker):
+
+        def echo_rank(self):
+            return self.rank
+
+    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
+              enforce_eager=True,
+              load_format="dummy",
+              tensor_parallel_size=tp_size,
+              distributed_executor_backend=backend,
+              worker_cls=MyWorker)
+    for method in ["echo_rank", echo_rank]:
+        assert llm.collective_rpc(method) == list(range(tp_size))
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 5d19ce03d5b58..88c21f9a6d31b 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -5,10 +5,10 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial
-from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
-                    List, Mapping, NamedTuple, Optional)
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
+                    Iterable, List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
-from typing import Set, Type, Union, cast, overload
+from typing import Set, Tuple, Type, Union, cast, overload
 
 import torch
 from typing_extensions import TypeVar, deprecated
@@ -1816,6 +1816,17 @@ def start_profile(self) -> None:
     def stop_profile(self) -> None:
         self.model_executor.stop_profile()
 
+    def collective_rpc(self,
+                       method: Union[str, Callable],
+                       timeout: Optional[float] = None,
+                       args: Tuple = (),
+                       kwargs: Optional[Dict] = None) -> List[Any]:
+        """
+        See LLM.collective_rpc for more details.
+        """
+        return self.model_executor.collective_rpc(method, timeout, args,
+                                                  kwargs)
+
     def check_health(self) -> None:
         if self.tokenizer:
             self.tokenizer.check_health()
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index b78d5c65a40f8..0cfe6be9ac767 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1,8 +1,8 @@
 import itertools
 import warnings
 from contextlib import contextmanager
-from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
-                    Union, cast, overload)
+from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
+                    Tuple, Type, Union, cast, overload)
 
 import cloudpickle
 from tqdm import tqdm
@@ -464,7 +464,7 @@ def generate(
         return self.engine_class.validate_outputs(outputs, RequestOutput)
 
     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
@@ -476,9 +476,13 @@ def collective_rpc(self,
         Then, users can call the new methods through this API.
         It is recommended to use this API to only pass control messages,
         and set up data-plane communication to pass data.
+        The method can also be a callable, which will be serialized
+        and sent to all workers to execute.
+        A callable method should accept an additional `self` argument
+        before the arguments passed in `args` and `kwargs`; the `self`
+        argument will be the worker object it runs on.
""" - return self.llm_engine.model_executor.collective_rpc( - method, timeout, args, kwargs) + return self.llm_engine.collective_rpc(method, timeout, args, kwargs) def beam_search( self, diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 00ecadcf92667..d8457cb693cdb 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,6 +1,7 @@ import asyncio from abc import ABC, abstractmethod -from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union +from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, + Union) from vllm.config import VllmConfig from vllm.logger import init_logger @@ -47,7 +48,7 @@ def _init_executor(self) -> None: @abstractmethod def collective_rpc(self, - method: str, + method: Union[str, Callable], timeout: Optional[float] = None, args: Tuple = (), kwargs: Optional[Dict] = None) -> List[Any]: @@ -260,7 +261,7 @@ def _driver_execute_model( raise NotImplementedError def collective_rpc(self, - method: str, + method: Union[str, Callable], timeout: Optional[float] = None, args: Tuple = (), kwargs: Optional[Dict] = None) -> List[Any]: @@ -269,7 +270,7 @@ def collective_rpc(self, @abstractmethod def _run_workers( self, - method: str, + method: Union[str, Callable], *args, async_run_tensor_parallel_workers_only: bool = False, max_concurrent_workers: Optional[int] = None, diff --git a/vllm/executor/mp_distributed_executor.py b/vllm/executor/mp_distributed_executor.py index d9dde949b844a..8ae88e646aad6 100644 --- a/vllm/executor/mp_distributed_executor.py +++ b/vllm/executor/mp_distributed_executor.py @@ -1,5 +1,7 @@ import asyncio -from typing import Any, List, Optional +from typing import Any, Callable, List, Optional, Union + +import cloudpickle from vllm.executor.executor_base import DistributedExecutorBase from vllm.executor.multiproc_worker_utils import ( @@ -9,7 +11,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest from vllm.utils import (_run_task_with_lock, get_distributed_init_method, - get_ip, get_open_port, make_async) + get_ip, get_open_port, make_async, run_method) from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -107,7 +109,7 @@ def _driver_execute_model( def _run_workers( self, - method: str, + method: Union[str, Callable], *args, async_run_tensor_parallel_workers_only: bool = False, max_concurrent_workers: Optional[int] = None, @@ -121,6 +123,11 @@ def _run_workers( It will also be run asynchronously and return a list of futures rather than blocking on the results. """ + if isinstance(method, str): + sent_method = method + else: + sent_method = cloudpickle.dumps(method) + del method if max_concurrent_workers: raise NotImplementedError( @@ -129,18 +136,18 @@ def _run_workers( if async_run_tensor_parallel_workers_only: # Run only non-driver workers and just return futures. return [ - worker.execute_method(method, *args, **kwargs) + worker.execute_method(sent_method, *args, **kwargs) for worker in self.non_driver_workers ] # Start all remote workers first. worker_outputs = [ - worker.execute_method(method, *args, **kwargs) + worker.execute_method(sent_method, *args, **kwargs) for worker in self.workers ] - driver_worker_method = getattr(self.driver_worker, method) - driver_worker_output = driver_worker_method(*args, **kwargs) + driver_worker_output = run_method(self.driver_worker, sent_method, + args, kwargs) # Get the results of the workers. 
         return [driver_worker_output
diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py
index c9fb3c664c575..539b6ae2d3572 100644
--- a/vllm/executor/multiproc_worker_utils.py
+++ b/vllm/executor/multiproc_worker_utils.py
@@ -15,7 +15,7 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.triton_utils.importing import HAS_TRITON
-from vllm.utils import _check_multiproc_method, get_mp_context
+from vllm.utils import _check_multiproc_method, get_mp_context, run_method
 
 if HAS_TRITON:
     from vllm.triton_utils import maybe_set_triton_cache_manager
@@ -169,7 +169,7 @@ def __init__(self, result_handler: ResultHandler,
         self.process.start()
 
     def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
-                      method: str, args, kwargs):
+                      method: Union[str, bytes], args, kwargs):
         task_id = uuid.uuid4()
         self.tasks[task_id] = future
         try:
@@ -180,12 +180,13 @@ def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future],
             del self.tasks[task_id]
             raise ChildProcessError("worker died") from e
 
-    def execute_method(self, method: str, *args, **kwargs):
+    def execute_method(self, method: Union[str, bytes], *args, **kwargs):
         future: ResultFuture = ResultFuture()
         self._enqueue_task(future, method, args, kwargs)
         return future
 
-    async def execute_method_async(self, method: str, *args, **kwargs):
+    async def execute_method_async(self, method: Union[str, bytes], *args,
+                                   **kwargs):
         future = asyncio.get_running_loop().create_future()
         self._enqueue_task(future, method, args, kwargs)
         return await future
@@ -230,8 +231,7 @@ def _run_worker_process(
             exception = None
             task_id, method, args, kwargs = items
             try:
-                executor = getattr(worker, method)
-                output = executor(*args, **kwargs)
+                output = run_method(worker, method, args, kwargs)
             except SystemExit:
                 raise
             except KeyboardInterrupt:
diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py
index 3baeb63918a62..2afd99f99b353 100644
--- a/vllm/executor/ray_distributed_executor.py
+++ b/vllm/executor/ray_distributed_executor.py
@@ -2,8 +2,9 @@
 import os
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 
+import cloudpickle
 import msgspec
 
 import vllm.envs as envs
@@ -410,7 +411,7 @@ def execute_model(
 
     def _run_workers(
         self,
-        method: str,
+        method: Union[str, Callable],
         *args,
         async_run_tensor_parallel_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
@@ -426,6 +427,11 @@ def _run_workers(
           rather than blocking on the results.
         - args/kwargs: All workers share the same args/kwargs
         """
+        if isinstance(method, str):
+            sent_method = method
+        else:
+            sent_method = cloudpickle.dumps(method)
+        del method
         if self.use_ray_spmd_worker:
             assert not async_run_tensor_parallel_workers_only, (
                 "async_run_tensor_parallel_workers_only is not supported for "
@@ -440,7 +446,7 @@ def _run_workers(
         if async_run_tensor_parallel_workers_only:
             ray_workers = self.non_driver_workers
         ray_worker_outputs = [
-            worker.execute_method.remote(method, *args, **kwargs)
+            worker.execute_method.remote(sent_method, *args, **kwargs)
             for worker in ray_workers
         ]
 
@@ -455,7 +461,7 @@ def _run_workers(
         if not self.use_ray_spmd_worker:
             # Start the driver worker after all the ray workers.
             driver_worker_output = [
-                self.driver_worker.execute_method(method, *args, **kwargs)
+                self.driver_worker.execute_method(sent_method, *args, **kwargs)
             ]
 
         # Get the results of the ray workers.
diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py
index 27b83e95ba95b..a5c4dcf0ec7f9 100644
--- a/vllm/executor/uniproc_executor.py
+++ b/vllm/executor/uniproc_executor.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -7,7 +7,8 @@
 import vllm.envs as envs
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
-from vllm.utils import get_distributed_init_method, get_ip, get_open_port
+from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
+                        run_method)
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -39,18 +40,13 @@ def _init_executor(self) -> None:
         self.collective_rpc("load_model")
 
     def collective_rpc(self,
-                       method: str,
+                       method: Union[str, Callable],
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
         if kwargs is None:
             kwargs = {}
-        try:
-            func = getattr(self.driver_worker, method)
-        except AttributeError:
-            raise NotImplementedError(f"Method {method} is not implemented.") \
-                from None
-        answer = func(*args, **kwargs)
+        answer = run_method(self.driver_worker, method, args, kwargs)
         return [answer]
 
     def check_health(self) -> None:
diff --git a/vllm/utils.py b/vllm/utils.py
index 7477e7028f5ef..89ba119bb5e55 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -36,6 +36,7 @@
                     overload)
 from uuid import uuid4
 
+import cloudpickle
 import numpy as np
 import numpy.typing as npt
 import psutil
@@ -2166,3 +2167,25 @@ def bind_kv_cache(
         assert len(forward_ctx.kv_cache) == len(kv_cache)
         for ve, ve_kv_cache in enumerate(kv_cache):
             forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx]
+
+
+def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any],
+               kwargs: Dict[str, Any]) -> Any:
+    """
+    Run a method of an object with the given arguments and keyword arguments.
+    If the method is a string, it will be resolved to a method via getattr.
+    If the method is serialized bytes, it will be deserialized using
+    cloudpickle.
+    If the method is a callable, it will be called directly.
+ """ + if isinstance(method, bytes): + func = partial(cloudpickle.loads(method), obj) + elif isinstance(method, str): + try: + func = getattr(obj, method) + except AttributeError: + raise NotImplementedError(f"Method {method!r} is not" + " implemented.") from None + else: + func = partial(method, obj) # type: ignore + return func(*args, **kwargs) diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index e92acc7cb5e41..fd977d07e8d81 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -6,9 +6,11 @@ import weakref from dataclasses import dataclass from enum import Enum, auto +from functools import partial from multiprocessing.process import BaseProcess -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import cloudpickle import psutil import zmq @@ -120,7 +122,7 @@ def get_kv_cache_spec(self) -> KVCacheSpec: return kv_cache_specs[0] def collective_rpc(self, - method: str, + method: Union[str, Callable], timeout: Optional[float] = None, args: Tuple = (), kwargs: Optional[Dict] = None) -> List[Any]: @@ -141,7 +143,12 @@ def collective_rpc(self, kwargs = kwargs or {} try: - self.rpc_broadcast_mq.enqueue((method, args, kwargs)) + if isinstance(method, str): + send_method = method + else: + send_method = cloudpickle.dumps( + method, protocol=pickle.HIGHEST_PROTOCOL) + self.rpc_broadcast_mq.enqueue((send_method, args, kwargs)) responses = [None] * self.world_size for w in self.workers: @@ -408,7 +415,11 @@ def worker_busy_loop(self): method, args, kwargs = self.rpc_broadcast_mq.dequeue() try: - output = getattr(self.worker, method)(*args, **kwargs) + if isinstance(method, str): + func = getattr(self.worker, method) + elif isinstance(method, bytes): + func = partial(cloudpickle.loads(method), self.worker) + output = func(*args, **kwargs) except Exception as e: self.worker_response_mq.enqueue( (WorkerProc.ResponseStatus.FAILURE, e)) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index bced5b9f44228..fb9919f7a7b6a 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -14,7 +14,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import (enable_trace_function_call_for_thread, - resolve_obj_by_qualname, update_environment_variables) + resolve_obj_by_qualname, run_method, + update_environment_variables) from vllm.worker.model_runner_base import (BroadcastableModelInput, ModelRunnerBase, ModelRunnerInputBase) @@ -539,17 +540,16 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: self.worker = worker_class(**kwargs) assert self.worker is not None - def execute_method(self, method: str, *args, **kwargs): + def execute_method(self, method: Union[str, bytes], *args, **kwargs): try: target = self if self.worker is None else self.worker - executor = getattr(target, method) - return executor(*args, **kwargs) + return run_method(target, method, args, kwargs) except Exception as e: # if the driver worker also execute methods, # exceptions in the rest worker may cause deadlock in rpc like ray # see https://github.com/vllm-project/vllm/issues/3455 # print the error and inform the user to solve the error - msg = (f"Error executing method {method}. " + msg = (f"Error executing method {method!r}. " "This might cause deadlock in distributed execution.") logger.exception(msg) raise e