Skip to content

Commit

Permalink
[BugFix] Don't scan entire cache dir when loading model (vllm-project…
Browse files Browse the repository at this point in the history
  • Loading branch information
njhill authored and panf2333 committed Feb 18, 2025
1 parent c241a75 commit e9a7f2a
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions vllm/model_executor/model_loader/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
import huggingface_hub.constants
import numpy as np
import torch
from huggingface_hub import (HfFileSystem, hf_hub_download, scan_cache_dir,
snapshot_download)
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm

Expand Down Expand Up @@ -239,7 +238,8 @@ def download_weights_from_hf(
Returns:
str: The path to the downloaded model weights.
"""
if not huggingface_hub.constants.HF_HUB_OFFLINE:
local_only = huggingface_hub.constants.HF_HUB_OFFLINE
if not local_only:
# Before we download we look at that is available:
fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
Expand All @@ -255,7 +255,6 @@ def download_weights_from_hf(
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
start_size = scan_cache_dir().size_on_disk
start_time = time.perf_counter()
hf_folder = snapshot_download(
model_name_or_path,
Expand All @@ -264,13 +263,12 @@ def download_weights_from_hf(
cache_dir=cache_dir,
tqdm_class=DisabledTqdm,
revision=revision,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
local_files_only=local_only,
)
end_time = time.perf_counter()
end_size = scan_cache_dir().size_on_disk
if end_size != start_size:
logger.info("Time took to download weights for %s: %.6f seconds",
model_name_or_path, end_time - start_time)
time_taken = time.perf_counter() - start_time
if time_taken > 0.5:
logger.info("Time spent downloading weights for %s: %.6f seconds",
model_name_or_path, time_taken)
return hf_folder


Expand Down

0 comments on commit e9a7f2a

Please sign in to comment.