Skip to content

Commit

Permalink
HF Downloads (#21)
Browse files Browse the repository at this point in the history
* allow hf:// paths and add hf-transfer for faster hf downloads
  • Loading branch information
farzadab authored Dec 17, 2024
1 parent 31b489d commit c357138
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 16 deletions.
66 changes: 65 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ torchaudio = "^2.4.1"
scipy = "^1.14.1"
einops = "^0.8.0"
praatio = "^6.2.0"
hf-transfer = "^0.1.8"

[tool.poetry.group.dev.dependencies]
black = "~24.4.2"
Expand Down
31 changes: 31 additions & 0 deletions ultravox/model/hf_hub_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import os

import huggingface_hub

# Only these file patterns are fetched from the Hub: safetensors weights plus
# JSON/TXT files (configs, tokenizer/vocab files). Everything else in the repo
# (e.g. *.bin duplicates, images, READMEs' assets) is skipped.
ALLOW_PATTERNS = ["*.safetensors", "*.json", "*.txt"]


def is_hf_model(model_id: str) -> bool:
    """Return True when *model_id* carries the "hf://" scheme marking a HF Hub repo id."""
    return model_id[:5] == "hf://"


def get_hf_model_id(model_id: str) -> str:
    """Strip a leading "hf://" scheme from *model_id*, if present.

    Uses prefix-length slicing rather than ``split("hf://")[1]`` so that a
    path which happens to contain "hf://" again later is only stripped of
    the leading scheme, not truncated at the second occurrence.

    Args:
        model_id: HF Hub repo id, optionally prefixed with "hf://".

    Returns:
        The bare repo id (or the input unchanged when no prefix is present).
    """
    prefix = "hf://"
    if model_id.startswith(prefix):
        return model_id[len(prefix):]
    return model_id


def download_hf_model(model_id: str, use_hf_transfer: bool = False) -> str:
    """Download a model snapshot from the HF Hub and return its local path.

    The model_id can be of format "hf://<repo_id>" to disambiguate from a
    local path, but a bare <repo_id> is also accepted.

    Args:
        model_id: HF Hub repo id, optionally prefixed with "hf://".
        use_hf_transfer: enable the hf-transfer backend for faster parallel
            downloads on high-bandwidth connections.

    Returns:
        Path of the downloaded snapshot directory.
    """
    model_id = get_hf_model_id(model_id)

    if use_hf_transfer:
        # Setting the env var alone is not sufficient once huggingface_hub is
        # already imported: the flag is snapshotted into
        # huggingface_hub.constants at import time. Set both so the backend is
        # actually enabled for this process and for any child processes.
        # NOTE(review): confirm against the pinned huggingface_hub version.
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
        huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True

    # Only weights/config files matching ALLOW_PATTERNS are fetched.
    return huggingface_hub.snapshot_download(
        repo_id=model_id, allow_patterns=ALLOW_PATTERNS
    )
14 changes: 10 additions & 4 deletions ultravox/model/wandb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@ def get_artifact(model_url: str) -> wandb.Artifact:
def download_model_from_wandb(model_url: str) -> str:
artifact = get_artifact(model_url)

for file in artifact.files():
if not any(file.name.endswith(path) for path in IGNORE_PATHS):
print("downloading", file.name)
model_path = artifact.download(path_prefix=file.name)
if any(
file.name.endswith(path) for file in artifact.files() for path in IGNORE_PATHS
):
# downloading one by one to avoid downloading the ignored files
for file in artifact.files():
if not any(file.name.endswith(path) for path in IGNORE_PATHS):
print("downloading", file.name)
model_path = artifact.download(path_prefix=file.name)
else:
model_path = artifact.download()

if model_path is None:
raise ValueError(f"No files to be downloaded.")
Expand Down
16 changes: 6 additions & 10 deletions ultravox/training/helpers/prefetch_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import List, Optional

import huggingface_hub
import transformers

from ultravox.model import hf_hub_utils
from ultravox.model import wandb_utils
from ultravox.training import config_base

Expand Down Expand Up @@ -55,13 +55,15 @@ def download_weights(
if key in run_config:
model_ids.append(run_config[key])

if model_load_dir and hf_hub_utils.is_hf_model(model_load_dir):
model_ids.append(hf_hub_utils.get_hf_model_id(model_load_dir))

for model_id in model_ids:
try:
# Download all model files that match ALLOW_PATTERNS
# This is faster than .from_pretrained due to parallel downloads
huggingface_hub.snapshot_download(
repo_id=model_id, allow_patterns=ALLOW_PATTERNS
)
# We can also use hf-transfer to download the files which is faster on fast internet connections
hf_hub_utils.download_hf_model(model_id)
except huggingface_hub.utils.GatedRepoError as e:
raise e
except huggingface_hub.utils.RepositoryNotFoundError as e:
Expand All @@ -71,12 +73,6 @@ def download_weights(
f"Model {model_id} not found on HF Hub. Skipping download. Error: {e}"
)

# A backstop to make sure the model is fully downloaded. Scenarios to consider:
# - ALLOW_PATTERNS is not enough to download all files needed
# - The model is local, this will verify that everything is in order
# Using `device_map="meta"` to avoid loading the weights into memory or device
transformers.AutoModel.from_pretrained(model_id, device_map="meta")

return model_path


Expand Down
5 changes: 4 additions & 1 deletion ultravox/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import wandb.sdk

from ultravox import data as datasets
from ultravox.model import hf_hub_utils
from ultravox.model import wandb_utils
from ultravox.training import config_base
from ultravox.training import ddp_utils
Expand Down Expand Up @@ -131,11 +132,13 @@ def train(args: config_base.TrainConfig):
# and hence this is just resolving the path. If the weights are not downloaded,
# we might see a race condition here when using DDP.
load_path = wandb_utils.download_model_from_wandb(load_path)
elif hf_hub_utils.is_hf_model(load_path):
load_path = hf_hub_utils.download_hf_model(load_path)
if os.path.isdir(load_path):
load_path = os.path.join(load_path, "model*.safetensors")
paths = glob.glob(load_path)
assert len(paths) > 0, f"No model files found at {load_path}"
for path in glob.glob(load_path):
for path in paths:
state_dict = safetensors.torch.load_file(path)
mismatch = model.load_state_dict(state_dict, strict=False)
if mismatch.unexpected_keys:
Expand Down

0 comments on commit c357138

Please sign in to comment.