From a799814cb86b43cb94013feb3b18f1166324bd29 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 8 Jul 2024 20:02:15 -0700 Subject: [PATCH] [hardware][cuda] use device id under CUDA_VISIBLE_DEVICES for get_device_capability (#6216) --- vllm/platforms/cuda.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index b2ca758131e92..2d482010cf760 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -2,6 +2,7 @@ pynvml. However, it should not initialize cuda context. """ +import os from functools import lru_cache, wraps from typing import Tuple @@ -23,12 +24,27 @@ def wrapper(*args, **kwargs): return wrapper +@lru_cache(maxsize=8) +@with_nvml_context +def get_physical_device_capability(device_id: int = 0) -> Tuple[int, int]: + handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) + return pynvml.nvmlDeviceGetCudaComputeCapability(handle) + + +def device_id_to_physical_device_id(device_id: int) -> int: + if "CUDA_VISIBLE_DEVICES" in os.environ: + device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + device_ids = [int(device_id) for device_id in device_ids] + physical_device_id = device_ids[device_id] + else: + physical_device_id = device_id + return physical_device_id + + class CudaPlatform(Platform): _enum = PlatformEnum.CUDA @staticmethod - @lru_cache(maxsize=8) - @with_nvml_context def get_device_capability(device_id: int = 0) -> Tuple[int, int]: - handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) - return pynvml.nvmlDeviceGetCudaComputeCapability(handle) + physical_device_id = device_id_to_physical_device_id(device_id) + return get_physical_device_capability(physical_device_id)