diff --git a/Dockerfile.ubi b/Dockerfile.ubi index f64114408..ff7e0b187 100644 --- a/Dockerfile.ubi +++ b/Dockerfile.ubi @@ -270,19 +270,15 @@ RUN microdnf install -y gcc \ && microdnf clean all # patch triton (fix for #720) -COPY triton_patch/cache_fix.patch . -RUN microdnf install -y patch \ - && patch /opt/vllm/lib/python3.11/site-packages/triton/runtime/cache.py cache_fix.patch \ - && microdnf remove -y patch \ - && microdnf clean all \ - && rm cache_fix.patch +COPY triton_patch/custom_cache_manager.py /opt/vllm/lib/python3.11/site-packages/triton/runtime/custom_cache_manager.py ENV HF_HUB_OFFLINE=1 \ PORT=8000 \ GRPC_PORT=8033 \ HOME=/home/vllm \ VLLM_USAGE_SOURCE=production-docker-image \ - VLLM_WORKER_MULTIPROC_METHOD=fork + VLLM_WORKER_MULTIPROC_METHOD=fork \ + TRITON_CACHE_MANAGER="triton.runtime.custom_cache_manager:CustomCacheManager" # setup non-root user for OpenShift RUN microdnf install -y shadow-utils \ diff --git a/triton_patch/cache_fix.patch b/triton_patch/cache_fix.patch deleted file mode 100644 index 97a1aa477..000000000 --- a/triton_patch/cache_fix.patch +++ /dev/null @@ -1,8 +0,0 @@ -4c4 -< import random ---- -> import uuid -117c117 -< rnd_id = random.randint(0, 1000000) ---- -> rnd_id = str(uuid.uuid4()) diff --git a/triton_patch/custom_cache_manager.py b/triton_patch/custom_cache_manager.py new file mode 100644 index 000000000..5c27c072e --- /dev/null +++ b/triton_patch/custom_cache_manager.py @@ -0,0 +1,32 @@ +import os + +from triton.runtime.cache import (FileCacheManager, default_cache_dir, + default_dump_dir, default_override_dir) + + +class CustomCacheManager(FileCacheManager): + + def __init__(self, key, override=False, dump=False): + self.key = key + self.lock_path = None + if dump: + self.cache_dir = default_dump_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + elif override: + self.cache_dir = default_override_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + else: + # create cache directory if it doesn't exist + self.cache_dir = os.getenv("TRITON_CACHE_DIR", + "").strip() or default_cache_dir() + if self.cache_dir: + self.cache_dir = f"{self.cache_dir}_{os.getpid()}" + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + else: + raise RuntimeError("Could not create or locate cache dir") + + print(f"Triton cache dir: {self.cache_dir=}") \ No newline at end of file