From a35456f1e6715940d45362da64cf59afaec9028b Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 17 Jan 2025 21:14:22 +0800
Subject: [PATCH 1/2] fix mp check

Signed-off-by: youkaichao
---
 vllm/executor/mp_distributed_executor.py | 36 ++++++++++++++++++++++--
 vllm/platforms/cuda.py                   | 22 ---------------
 2 files changed, 34 insertions(+), 24 deletions(-)

diff --git a/vllm/executor/mp_distributed_executor.py b/vllm/executor/mp_distributed_executor.py
index 8ae88e646aad6..dd32c8e21d9a5 100644
--- a/vllm/executor/mp_distributed_executor.py
+++ b/vllm/executor/mp_distributed_executor.py
@@ -1,4 +1,5 @@
 import asyncio
+import os
 from typing import Any, Callable, List, Optional, Union
 
 import cloudpickle
@@ -10,8 +11,9 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
-                        get_ip, get_open_port, make_async, run_method)
+from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
+                        get_distributed_init_method, get_ip, get_open_port,
+                        make_async, run_method, update_environment_variables)
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -22,7 +24,37 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
 
     uses_ray: bool = False
 
+    def _check_cuda(self) -> None:
+        """Check that the number of GPUs is sufficient for the parallel
+        configuration. Separate from _init_executor to reduce the number of
+        indented blocks.
+        """
+        parallel_config = self.parallel_config
+        world_size = parallel_config.world_size
+        tensor_parallel_size = parallel_config.tensor_parallel_size
+
+        cuda_device_count = cuda_device_count_stateless()
+        # Use confusing message for more common TP-only case.
+        assert tensor_parallel_size <= cuda_device_count, (
+            f"please set tensor_parallel_size ({tensor_parallel_size}) "
+            f"to less than max local gpu count ({cuda_device_count})")
+
+        assert world_size <= cuda_device_count, (
+            f"please ensure that world_size ({world_size}) "
+            f"is less than than max local gpu count ({cuda_device_count})")
+
+        # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
+        if "CUDA_VISIBLE_DEVICES" not in os.environ:
+            update_environment_variables({
+                "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
+            })
+
     def _init_executor(self) -> None:
+
+        from vllm.platforms import current_platform
+        if current_platform.is_cuda_alike():
+            self._check_cuda()
+
         # Create the parallel GPU workers.
         world_size = self.parallel_config.world_size
         tensor_parallel_size = self.parallel_config.tensor_parallel_size
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 8350177b68ade..2587e3a11dde3 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -139,28 +139,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         else:
             parallel_config.worker_cls = "vllm.worker.worker.Worker"
 
-        world_size = parallel_config.world_size
-        tensor_parallel_size = parallel_config.tensor_parallel_size
-
-        from vllm.utils import (cuda_device_count_stateless,
-                                update_environment_variables)
-
-        # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
-        if "CUDA_VISIBLE_DEVICES" not in os.environ:
-            update_environment_variables({
-                "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
-            })
-
-        cuda_device_count = cuda_device_count_stateless()
-        # Use confusing message for more common TP-only case.
-        assert tensor_parallel_size <= cuda_device_count, (
-            f"please set tensor_parallel_size ({tensor_parallel_size}) "
-            f"to less than max local gpu count ({cuda_device_count})")
-
-        assert world_size <= cuda_device_count, (
-            f"please ensure that world_size ({world_size}) "
-            f"is less than than max local gpu count ({cuda_device_count})")
-
         cache_config = vllm_config.cache_config
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16

From 421dcb44c5e5136f0b800e0704374ab2d51a1b89 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sat, 18 Jan 2025 10:49:38 +0800
Subject: [PATCH 2/2] use runtime error

Signed-off-by: youkaichao
---
 vllm/executor/mp_distributed_executor.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/vllm/executor/mp_distributed_executor.py b/vllm/executor/mp_distributed_executor.py
index dd32c8e21d9a5..a80b0ee8b3122 100644
--- a/vllm/executor/mp_distributed_executor.py
+++ b/vllm/executor/mp_distributed_executor.py
@@ -35,13 +35,15 @@ def _check_cuda(self) -> None:
 
         cuda_device_count = cuda_device_count_stateless()
         # Use confusing message for more common TP-only case.
-        assert tensor_parallel_size <= cuda_device_count, (
-            f"please set tensor_parallel_size ({tensor_parallel_size}) "
-            f"to less than max local gpu count ({cuda_device_count})")
-
-        assert world_size <= cuda_device_count, (
-            f"please ensure that world_size ({world_size}) "
-            f"is less than than max local gpu count ({cuda_device_count})")
+        if tensor_parallel_size > cuda_device_count:
+            raise RuntimeError(
+                f"please set tensor_parallel_size ({tensor_parallel_size}) "
+                f"to less than max local gpu count ({cuda_device_count})")
+
+        if world_size > cuda_device_count:
+            raise RuntimeError(
+                f"please ensure that world_size ({world_size}) "
+                f"is less than than max local gpu count ({cuda_device_count})")
 
         # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
         if "CUDA_VISIBLE_DEVICES" not in os.environ:
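
A note on patch 2/2, whose commit message gives no rationale beyond its
subject line: one concrete advantage of `raise RuntimeError` over `assert`
is that assert statements are removed entirely when the interpreter runs
with optimizations enabled (`python -O`), so the GPU-count check would
silently vanish in such deployments, while an explicit raise always fires
and hands callers a catchable, well-typed exception rather than an
AssertionError that reads like an internal bug. A minimal standalone
sketch of the difference (the helper names below are illustrative, not
vLLM functions):

    def check_with_assert(tensor_parallel_size: int, gpu_count: int) -> None:
        # Compiled away under `python -O`: an oversized TP size would pass
        # here and only fail later, deep inside worker startup.
        assert tensor_parallel_size <= gpu_count, (
            f"please set tensor_parallel_size ({tensor_parallel_size}) "
            f"to less than max local gpu count ({gpu_count})")


    def check_with_error(tensor_parallel_size: int, gpu_count: int) -> None:
        # Survives `python -O` and raises a catchable, well-typed exception.
        if tensor_parallel_size > gpu_count:
            raise RuntimeError(
                f"please set tensor_parallel_size ({tensor_parallel_size}) "
                f"to less than max local gpu count ({gpu_count})")


    if __name__ == "__main__":
        try:
            check_with_error(4, 2)
        except RuntimeError as exc:
            print(f"caught: {exc}")  # raised in every mode, including -O
        check_with_assert(4, 2)      # AssertionError normally; no-op under -O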