
[70B-Part1] Prefetch weights separately #106

Merged · 16 commits · Sep 13, 2024
5 changes: 4 additions & 1 deletion mcloud.yaml
@@ -10,7 +10,10 @@ integrations:
     git_branch: $UV_BRANCH
     pip_install: poetry==1.7.1
 command: >-
-  cd ultravox && poetry install --no-dev && poetry run torchrun --nproc_per_node=8 -m ultravox.training.train $TRAIN_ARGS
+  cd ultravox &&
+  poetry install --no-dev &&
+  poetry run python -m ultravox.training.helpers.prefetch_weights $TRAIN_ARGS &&
+  poetry run torchrun --nproc_per_node=8 -m ultravox.training.train $TRAIN_ARGS
 env_variables:
   MLFLOW_TRACKING_URI: databricks
   UV_BRANCH: main
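The intent of the new command: prefetch_weights runs once, in a single process, before torchrun fans out to eight ranks, so the ranks never race each other on the Hugging Face cache. A rough Python equivalent of the same sequence, for illustration only (the module paths come from the diff; everything else is a stand-in):

import subprocess
import sys

train_args = sys.argv[1:]  # plays the role of $TRAIN_ARGS

# Single-process prefetch warms the HF cache before any training rank starts.
subprocess.run(
    [sys.executable, "-m", "ultravox.training.helpers.prefetch_weights", *train_args],
    check=True,
)
subprocess.run(
    ["torchrun", "--nproc_per_node=8", "-m", "ultravox.training.train", *train_args],
    check=True,
)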
24 changes: 24 additions & 0 deletions ultravox/training/config_base.py
@@ -2,6 +2,8 @@
 import datetime
 import logging
 import os
+import re
+import sys
 from pathlib import Path
 from typing import Any, Dict, List, Optional

@@ -130,3 +132,25 @@ def __post_init__(self):
                 "LayerDrop cannot be used in DDP when encoder is not frozen. Disabling LayerDrop."
             )
             self.disable_layerdrop = True
+
+
+def fix_hyphens(arg: str):
+    return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg)
+
+
+def get_train_args(override_sys_args: Optional[List[str]] = None) -> TrainConfig:
+    """
+    Parse the command line arguments and return a TrainConfig object.
+
+    Args:
+        override_sys_args: The command line arguments. If None, sys.argv[1:] is used.
+            This is mainly useful for testing.
+    """
+    args = override_sys_args or sys.argv[1:]
+
+    return simple_parsing.parse(
+        config_class=TrainConfig,
+        config_path=os.path.join(os.path.dirname(__file__), "configs/meta_config.yaml"),
+        add_config_path_arg=True,
+        args=[fix_hyphens(arg) for arg in args],
+    )
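A note for reviewers on fix_hyphens: it rewrites only the flag name before any `=`, converting hyphens to the underscores that simple_parsing expects, and leaves values and positional arguments untouched. A quick sketch (flag values are made up):

# a minimal sketch of fix_hyphens behavior; values are hypothetical
from ultravox.training import config_base

assert config_base.fix_hyphens("--text-model=some-org/some-model") == "--text_model=some-org/some-model"
assert config_base.fix_hyphens("--num-epochs") == "--num_epochs"
assert config_base.fix_hyphens("not-a-flag") == "not-a-flag"  # no leading "--", left as-is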
60 changes: 60 additions & 0 deletions ultravox/training/helpers/prefetch_weights.py
@@ -0,0 +1,60 @@
from datetime import datetime
from typing import List, Optional

import huggingface_hub
import transformers

from ultravox.model import wandb_utils
from ultravox.training import config_base

ALLOW_PATTERNS = ["*.safetensors", "*.json"]


def main(override_sys_args: Optional[List[str]] = None):
    start = datetime.now()
    print("Downloading weights ...")

    args = config_base.get_train_args(override_sys_args)

    for model_id in [args.text_model, args.audio_model]:
        try:
            # Download all model files that match ALLOW_PATTERNS.
            # This is faster than .from_pretrained due to parallel downloads.
            huggingface_hub.snapshot_download(
                repo_id=model_id, allow_patterns=ALLOW_PATTERNS
            )
        except huggingface_hub.utils.GatedRepoError as e:
            raise e
        except huggingface_hub.utils.RepositoryNotFoundError as e:
            # We assume that the model is local if it's not found on HF Hub.
            # The `.from_pretrained` call below will verify the local case.
            print(
                f"Model {model_id} not found on HF Hub. Skipping download. Error: {e}"
            )

        # A backstop to make sure the model is fully downloaded. Scenarios to consider:
        # - ALLOW_PATTERNS is not enough to download all files needed
        # - The model is local; this verifies that everything is in order
        # Using `device_map="meta"` avoids loading the weights into memory or onto a device.
        transformers.AutoModel.from_pretrained(model_id, device_map="meta")

    if args.model_load_dir and wandb_utils.is_wandb_url(args.model_load_dir):
        wandb_utils.download_model_from_wandb(args.model_load_dir)

    end = datetime.now()
    print(f"Weights downloaded in {end - start}.")


def raise_on_weights_not_downloaded(model_ids: List[str]):
    """
    An imperfect check that the model weights have been downloaded; it can at
    least catch the case where prefetch_weights.py was not run.
    """
    for model_id in model_ids:
        huggingface_hub.snapshot_download(
            repo_id=model_id, allow_patterns=ALLOW_PATTERNS, local_files_only=True
        )


if __name__ == "__main__":
    main()
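Two behaviors above are worth spelling out: snapshot_download with local_files_only=True raises huggingface_hub.utils.LocalEntryNotFoundError when the matching files are missing from the local cache, and from_pretrained with device_map="meta" instantiates the architecture without materializing any weights. A minimal standalone sketch of the same checks (the model id is a stand-in):

import huggingface_hub
import transformers

model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # stand-in

# Raises LocalEntryNotFoundError if files matching the patterns are not cached.
huggingface_hub.snapshot_download(
    repo_id=model_id, allow_patterns=["*.safetensors", "*.json"], local_files_only=True
)

# Builds the model on the meta device: all config/weight files must be present
# and readable, but no tensors are allocated in memory.
transformers.AutoModel.from_pretrained(model_id, device_map="meta")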
20 changes: 20 additions & 0 deletions ultravox/training/helpers/prefetch_weights_test.py
@@ -0,0 +1,20 @@
import transformers

from ultravox.training.helpers import prefetch_weights

TEXT_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM"
AUDIO_MODEL = "hf-internal-testing/tiny-random-WhisperForCausalLM"


def test_prefetch_weights():
    # It would be nice to test this, but there isn't an easy way to clear the cache:
    # with pytest.raises(huggingface_hub.utils.LocalEntryNotFoundError):
    #     prefetch_weights.raise_on_weights_not_downloaded([TEXT_MODEL, AUDIO_MODEL])

    prefetch_weights.main(["--text-model", TEXT_MODEL, "--audio-model", AUDIO_MODEL])

    prefetch_weights.raise_on_weights_not_downloaded([TEXT_MODEL, AUDIO_MODEL])

    # With local_files_only=True, from_pretrained will throw an error if the
    # weights are not downloaded.
    transformers.AutoModel.from_pretrained(TEXT_MODEL, local_files_only=True)
    transformers.AutoModel.from_pretrained(AUDIO_MODEL, local_files_only=True)
42 changes: 20 additions & 22 deletions ultravox/training/train.py
@@ -4,16 +4,13 @@
 import glob
 import logging
 import os
-import re
 import subprocess
-import sys
 from datetime import datetime
 from typing import Dict, List, Optional

 import datasets as hf_datasets
 import pandas as pd
 import safetensors.torch
-import simple_parsing
 import torch
 import torch.distributed
 import transformers
@@ -29,16 +26,12 @@
 from ultravox.model import ultravox_processing
 from ultravox.model import wandb_utils
 from ultravox.training import config_base
-from ultravox.training import ddp_utils
+from ultravox.training.helpers import prefetch_weights

 INPUT_EXAMPLE = {"text": "Transcribe\n<|audio|>", "audio": b"\x00\x00" * 16000}
 OUTPUT_EXAMPLE = {"text": "Hello, world!"}


-def fix_hyphens(arg: str):
-    return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg)
-
-
 def prepare_dataset(
     train_args: config_base.TrainConfig,
     dataset_names: List[str],
@@ -76,12 +69,7 @@ def main() -> None:
     os.environ["WANDB_LOG_MODEL"] = "checkpoint"
     os.environ["WANDB_PROJECT"] = "ultravox"

-    args = simple_parsing.parse(
-        config_class=config_base.TrainConfig,
-        config_path="ultravox/training/configs/meta_config.yaml",  # base config file
-        add_config_path_arg=True,
-        args=[fix_hyphens(arg) for arg in sys.argv[1:]],
-    )
+    args = config_base.get_train_args()

     transformers.set_seed(args.seed)

@@ -100,8 +88,9 @@ def train(args: config_base.TrainConfig):
     world_size = int(os.environ.get("WORLD_SIZE", 1))
     local_rank = int(os.environ.get("LOCAL_RANK", 0))
     is_master = local_rank == 0
+    is_distributed = world_size > 1

-    if world_size > 1:
+    if is_distributed:
         torch.distributed.init_process_group(backend="nccl")

     # DDP blows up logging, so this is an attempt to suppress it to only logs from the master process
@@ -129,10 +118,18 @@ def train(args: config_base.TrainConfig):

     logging.info("Instantiating model...")

-    # Since the model downloads the language model and audio encoder weights, we want one process to finish up
-    # downloading before the others start in order to avoid race conditions.
-    with ddp_utils.run_on_master_first(is_master):
-        model = ultravox_model.UltravoxModel(config)
+    if is_distributed:
+        try:
+            prefetch_weights.raise_on_weights_not_downloaded(
+                [args.text_model, args.audio_model]
+            )
+        except Exception as e:
+            # We assume that the weights are already downloaded via prefetch_weights.py.
+            # If the weights are not downloaded, we might see a race condition here when using DDP/FSDP.
+            logging.error("Weights are not downloaded. Please run prefetch_weights.py.")
+            raise e
+
+    model = ultravox_model.UltravoxModel(config)

     assert model.get_input_embeddings().num_embeddings == len(
         text_tokenizer
@@ -166,9 +163,10 @@ def train(args: config_base.TrainConfig):
         logging.info(f"Loading model state dict from {args.model_load_dir}")
         load_path = args.model_load_dir
         if wandb_utils.is_wandb_url(load_path):
-            # Download the model from W&B. The main process should do the download while the others wait.
-            with ddp_utils.run_on_master_first(is_master):
-                load_path = wandb_utils.download_model_from_wandb(load_path)
+            # We assume that the weights are already downloaded via prefetch_weights.py
+            # and hence this is just resolving the path. If the weights are not downloaded,
+            # we might see a race condition here when using DDP.
+            load_path = wandb_utils.download_model_from_wandb(load_path)
         if os.path.isdir(load_path):
             load_path = os.path.join(load_path, "model*.safetensors")
         paths = glob.glob(load_path)
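For context on the removed ddp_utils.run_on_master_first blocks: that helper let the master rank do the download while the other ranks waited at a barrier, which this PR replaces with the up-front prefetch. A sketch of the general master-first pattern (not necessarily the exact ddp_utils implementation):

import contextlib

import torch.distributed as dist


@contextlib.contextmanager
def run_on_master_first(is_master: bool):
    # Master runs the body first (e.g. a download); the barrier then releases
    # the remaining ranks to run the same body against a warm cache.
    if is_master:
        yield
        if dist.is_initialized():
            dist.barrier()
    else:
        if dist.is_initialized():
            dist.barrier()
        yield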