Address review comments
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
tdoublep committed Jul 4, 2024
1 parent 81eef8a commit b040645
Showing 4 changed files with 22 additions and 10 deletions.
4 changes: 4 additions & 0 deletions vllm/executor/multiproc_gpu_executor.py
@@ -9,6 +9,7 @@
                                                   ResultHandler, WorkerMonitor)
 from vllm.logger import init_logger
 from vllm.sequence import ExecuteModelRequest, SamplerOutput
+from vllm.triton_utils import maybe_set_triton_cache_manager
 from vllm.utils import (cuda_device_count_stateless,
                         error_on_invalid_device_count_status,
                         get_distributed_init_method, get_open_port,
@@ -42,6 +43,9 @@ def _init_executor(self) -> None:
if "OMP_NUM_THREADS" not in os.environ:
os.environ["OMP_NUM_THREADS"] = "1"

# workaround for https://github.com/vllm-project/vllm/issues/6103
maybe_set_triton_cache_manager()

assert world_size <= cuda_device_count_stateless(), (
"please set tensor_parallel_size to less than max local gpu count")

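The value placed in TRITON_CACHE_MANAGER is a "module:Class" path that Triton later resolves to a cache-manager class. The sketch below is a plain illustration of how such a path can be resolved with importlib; it is not vLLM or Triton code, and the helper name resolve_manager_class is made up:

    import importlib
    import os

    def resolve_manager_class(path: str) -> type:
        # Split "package.module:ClassName" into module and attribute parts,
        # import the module, then fetch the class from it.
        module_name, _, class_name = path.partition(":")
        module = importlib.import_module(module_name)
        return getattr(module, class_name)

    # Hypothetical usage: resolve whatever was placed in the env var.
    manager_cls = resolve_manager_class(os.environ["TRITON_CACHE_MANAGER"])

Setting the variable inside _init_executor, before the worker processes are created, means each worker inherits it in its environment rather than having to repeat the workaround per call.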
10 changes: 0 additions & 10 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -272,12 +272,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
     )
 
 
-def maybe_set_triton_cache_manager(module: str) -> None:
-    cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
-    if cache_manger != module:
-        os.environ["TRITON_CACHE_MANAGER"] = module
-
-
 def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
     device_name = torch.cuda.get_device_name().replace(" ", "_")
     dtype_selector = "" if not dtype else f",dtype={dtype}"
@@ -434,10 +428,6 @@ def fused_experts(hidden_states: torch.Tensor,
     CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
     M = min(num_tokens, CHUNK_SIZE)
 
-    # workaround for https://github.com/vllm-project/vllm/issues/6103
-    maybe_set_triton_cache_manager(
-        "vllm.triton_utils.custom_cache_manager:CustomCacheManager")
-
     if override_config:
         config = override_config
     else:
6 changes: 6 additions & 0 deletions vllm/triton_utils/__init__.py
@@ -0,0 +1,6 @@
+from vllm.triton_utils.custom_cache_manager import (
+    maybe_set_triton_cache_manager)
+
+__all__ = [
+    "maybe_set_triton_cache_manager",
+]
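With this re-export in place, callers import the helper from the package root instead of reaching into fused_moe. A minimal usage sketch (a standalone snippet, not part of this diff):

    from vllm.triton_utils import maybe_set_triton_cache_manager

    # Sets TRITON_CACHE_MANAGER only if the user has not already set it.
    maybe_set_triton_cache_manager()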
12 changes: 12 additions & 0 deletions vllm/triton_utils/custom_cache_manager.py
@@ -3,6 +3,18 @@
 from triton.runtime.cache import (FileCacheManager, default_cache_dir,
                                   default_dump_dir, default_override_dir)
 
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def maybe_set_triton_cache_manager() -> None:
+    cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
+    if cache_manager is None:
+        manager = "vllm.triton_utils.custom_cache_manager:CustomCacheManager"
+        logger.info("Setting Triton cache manager to: %s", manager)
+        os.environ["TRITON_CACHE_MANAGER"] = manager
+
+
 class CustomCacheManager(FileCacheManager):
 
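Because the helper only writes the variable when it is unset, an explicit user choice always wins. A small sketch of that behaviour (the "my_pkg.cache:MyCacheManager" value is a made-up placeholder, not a real manager):

    import os

    from vllm.triton_utils import maybe_set_triton_cache_manager

    os.environ["TRITON_CACHE_MANAGER"] = "my_pkg.cache:MyCacheManager"  # hypothetical
    maybe_set_triton_cache_manager()
    # The pre-existing value is left untouched; vLLM's CustomCacheManager is
    # installed as the default only when nothing was configured.
    assert os.environ["TRITON_CACHE_MANAGER"] == "my_pkg.cache:MyCacheManager"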
