From 4965726eceac50497f464dc431e18f33980a1183 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Tue, 10 Sep 2024 11:52:22 -0700 Subject: [PATCH 01/16] prefetch weights separately --- mcloud.yaml | 2 +- ultravox/training/helpers/prefetch_weights.py | 37 +++++++++++++++++++ ultravox/training/train.py | 14 +++---- 3 files changed, 45 insertions(+), 8 deletions(-) create mode 100644 ultravox/training/helpers/prefetch_weights.py diff --git a/mcloud.yaml b/mcloud.yaml index 99788954..427ac9f2 100644 --- a/mcloud.yaml +++ b/mcloud.yaml @@ -10,7 +10,7 @@ integrations: git_branch: $UV_BRANCH pip_install: poetry==1.7.1 command: >- - cd ultravox && poetry install --no-dev && poetry run torchrun --nproc_per_node=8 -m ultravox.training.train $TRAIN_ARGS + cd ultravox && poetry install --no-dev && poetry run python -m ultravox.training.helpers.prefetch_weights $TRAIN_ARGS && poetry run torchrun --nproc_per_node=8 -m ultravox.training.train $TRAIN_ARGS env_variables: MLFLOW_TRACKING_URI: databricks UV_BRANCH: main diff --git a/ultravox/training/helpers/prefetch_weights.py b/ultravox/training/helpers/prefetch_weights.py new file mode 100644 index 00000000..44a9bf30 --- /dev/null +++ b/ultravox/training/helpers/prefetch_weights.py @@ -0,0 +1,37 @@ +from datetime import datetime + +import huggingface_hub + +from ultravox.training import config_base +from ultravox.model import wandb_utils + +ALLOW_PATTERNS = ["*.safetensors", "*.json"] + + +def main(): + start = datetime.now() + print("Downloading weights ...") + + args = config_base.get_train_args() + + for model_id in [args.text_model, args.audio_model]: + try: + huggingface_hub.snapshot_download( + repo_id=model_id, allow_patterns=ALLOW_PATTERNS + ) + except huggingface_hub.utils.GatedRepoError as e: + raise e + except huggingface_hub.utils.RepositoryNotFoundError as e: + print( + f"Model {args.text_model} not found on HF Hub. Skipping download. Error: {e}" + ) + + if args.model_load_dir and wandb_utils.is_wandb_url(args.model_load_dir): + wandb_utils.download_model_from_wandb(args.model_load_dir) + + end = datetime.now() + print(f"Weights are downloaded in {end - start} seconds") + + +if __name__ == "__main__": + main() diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 05cea992..e6352959 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -129,10 +129,9 @@ def train(args: config_base.TrainConfig): logging.info("Instantiating model...") - # Since the model downloads the language model and audio encoder weights, we want one process to finish up - # downloading before the others start in order to avoid race conditions. - with ddp_utils.run_on_master_first(is_master): - model = ultravox_model.UltravoxModel(config) + # We assume that the weights are already downloaded via prefetch_weights.py + # If the weights are not downloaded, we might see a race condition here when using DDP. + model = ultravox_model.UltravoxModel(config) assert model.get_input_embeddings().num_embeddings == len( text_tokenizer @@ -166,9 +165,10 @@ def train(args: config_base.TrainConfig): logging.info(f"Loading model state dict from {args.model_load_dir}") load_path = args.model_load_dir if wandb_utils.is_wandb_url(load_path): - # Download the model from W&B. The main process should do the download while the others wait. - with ddp_utils.run_on_master_first(is_master): - load_path = wandb_utils.download_model_from_wandb(load_path) + # We assume that the weights are already downloaded via prefetch_weights.py + # and hence this is just resolving the path. If the weights are not downloaded, + # we might see a race condition here when using DDP. + load_path = wandb_utils.download_model_from_wandb(load_path) if os.path.isdir(load_path): load_path = os.path.join(load_path, "model*.safetensors") paths = glob.glob(load_path) From 83908520e3ea3777c990dde547a8ba10aecd14ae Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Tue, 10 Sep 2024 11:56:15 -0700 Subject: [PATCH 02/16] formatting --- ultravox/training/helpers/prefetch_weights.py | 2 +- ultravox/training/train.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/ultravox/training/helpers/prefetch_weights.py b/ultravox/training/helpers/prefetch_weights.py index 44a9bf30..ddc28355 100644 --- a/ultravox/training/helpers/prefetch_weights.py +++ b/ultravox/training/helpers/prefetch_weights.py @@ -2,8 +2,8 @@ import huggingface_hub -from ultravox.training import config_base from ultravox.model import wandb_utils +from ultravox.training import config_base ALLOW_PATTERNS = ["*.safetensors", "*.json"] diff --git a/ultravox/training/train.py b/ultravox/training/train.py index e6352959..424676d6 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -29,7 +29,6 @@ from ultravox.model import ultravox_processing from ultravox.model import wandb_utils from ultravox.training import config_base -from ultravox.training import ddp_utils INPUT_EXAMPLE = {"text": "Transcribe\n<|audio|>", "audio": b"\x00\x00" * 16000} OUTPUT_EXAMPLE = {"text": "Hello, world!"} From 2d472d7d10870cee793f31283587678416a6589e Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Tue, 10 Sep 2024 12:01:32 -0700 Subject: [PATCH 03/16] moving get_train_args to config --- ultravox/training/config_base.py | 15 +++++++++++++++ ultravox/training/train.py | 14 +------------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index 232c0584..f4e3f0f7 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -2,6 +2,8 @@ import datetime import logging import os +import re +import sys from pathlib import Path from typing import Any, Dict, List, Optional @@ -130,3 +132,16 @@ def __post_init__(self): "LayerDrop cannot be used in DDP when encoder is not frozen. Disabling LayerDrop." ) self.disable_layerdrop = True + + +def fix_hyphens(arg: str): + return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg) + + +def get_train_args() -> TrainConfig: + return simple_parsing.parse( + config_class=TrainConfig, + config_path="ultravox/training/configs/meta_config.yaml", # base config file + add_config_path_arg=True, + args=[fix_hyphens(arg) for arg in sys.argv[1:]], + ) diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 424676d6..e5f6d7be 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -4,16 +4,13 @@ import glob import logging import os -import re import subprocess -import sys from datetime import datetime from typing import Dict, List, Optional import datasets as hf_datasets import pandas as pd import safetensors.torch -import simple_parsing import torch import torch.distributed import transformers @@ -34,10 +31,6 @@ OUTPUT_EXAMPLE = {"text": "Hello, world!"} -def fix_hyphens(arg: str): - return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg) - - def prepare_dataset( train_args: config_base.TrainConfig, dataset_names: List[str], @@ -75,12 +68,7 @@ def main() -> None: os.environ["WANDB_LOG_MODEL"] = "checkpoint" os.environ["WANDB_PROJECT"] = "ultravox" - args = simple_parsing.parse( - config_class=config_base.TrainConfig, - config_path="ultravox/training/configs/meta_config.yaml", # base config file - add_config_path_arg=True, - args=[fix_hyphens(arg) for arg in sys.argv[1:]], - ) + args = config_base.get_train_args() transformers.set_seed(args.seed) From 008092f11aac800ce267cca44ff9477ada46b5da Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Tue, 10 Sep 2024 12:38:25 -0700 Subject: [PATCH 04/16] add a test for prefetch_weights --- ultravox/training/config_base.py | 16 +++++++++++++--- ultravox/training/helpers/prefetch_weights.py | 5 +++-- .../training/helpers/prefetch_weights_test.py | 13 +++++++++++++ 3 files changed, 29 insertions(+), 5 deletions(-) create mode 100644 ultravox/training/helpers/prefetch_weights_test.py diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index f4e3f0f7..3204e5a7 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -138,10 +138,20 @@ def fix_hyphens(arg: str): return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg) -def get_train_args() -> TrainConfig: +def get_train_args(args: Optional[List[str]] = None) -> TrainConfig: + """ + Parse the command line arguments and return a TrainConfig object. + + Args: + args: The command line arguments. If None, sys.argv[1:] is used. + This is mainly useful for testing. + """ + if args is None: + args = sys.argv[1:] + return simple_parsing.parse( config_class=TrainConfig, - config_path="ultravox/training/configs/meta_config.yaml", # base config file + config_path=os.path.join(os.path.dirname(__file__), "configs/meta_config.yaml"), add_config_path_arg=True, - args=[fix_hyphens(arg) for arg in sys.argv[1:]], + args=[fix_hyphens(arg) for arg in args], ) diff --git a/ultravox/training/helpers/prefetch_weights.py b/ultravox/training/helpers/prefetch_weights.py index ddc28355..271f4756 100644 --- a/ultravox/training/helpers/prefetch_weights.py +++ b/ultravox/training/helpers/prefetch_weights.py @@ -1,3 +1,4 @@ +from typing import List, Optional from datetime import datetime import huggingface_hub @@ -8,11 +9,11 @@ ALLOW_PATTERNS = ["*.safetensors", "*.json"] -def main(): +def main(args: Optional[List[str]] = None): start = datetime.now() print("Downloading weights ...") - args = config_base.get_train_args() + args = config_base.get_train_args(args) for model_id in [args.text_model, args.audio_model]: try: diff --git a/ultravox/training/helpers/prefetch_weights_test.py b/ultravox/training/helpers/prefetch_weights_test.py new file mode 100644 index 00000000..e6ce4f13 --- /dev/null +++ b/ultravox/training/helpers/prefetch_weights_test.py @@ -0,0 +1,13 @@ +import transformers +from ultravox.training.helpers import prefetch_weights + +TEXT_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM" +AUDIO_MODEL = "hf-internal-testing/tiny-random-WhisperForCausalLM" + + +def test_prefetch_weights(): + prefetch_weights.main(["--text-model", TEXT_MODEL, "--audio-model", AUDIO_MODEL]) + + # With local_files_only=True, from_pretrained will throw an error if the weights are not downloaded + transformers.AutoModel.from_pretrained(TEXT_MODEL, local_files_only=True) + transformers.AutoModel.from_pretrained(AUDIO_MODEL, local_files_only=True) From c45603f1965fcc2940906e2b772967a4f930175d Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Tue, 10 Sep 2024 13:15:17 -0700 Subject: [PATCH 05/16] formatting --- ultravox/training/helpers/prefetch_weights.py | 2 +- ultravox/training/helpers/prefetch_weights_test.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/ultravox/training/helpers/prefetch_weights.py b/ultravox/training/helpers/prefetch_weights.py index 271f4756..3e5ad19e 100644 --- a/ultravox/training/helpers/prefetch_weights.py +++ b/ultravox/training/helpers/prefetch_weights.py @@ -1,5 +1,5 @@ -from typing import List, Optional from datetime import datetime +from typing import List, Optional import huggingface_hub diff --git a/ultravox/training/helpers/prefetch_weights_test.py b/ultravox/training/helpers/prefetch_weights_test.py index e6ce4f13..19468c6b 100644 --- a/ultravox/training/helpers/prefetch_weights_test.py +++ b/ultravox/training/helpers/prefetch_weights_test.py @@ -1,4 +1,5 @@ import transformers + from ultravox.training.helpers import prefetch_weights TEXT_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM" From 8365fa554eb8d343e9354ca8e91a76f962c70a17 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Tue, 10 Sep 2024 13:19:15 -0700 Subject: [PATCH 06/16] fix name clash --- ultravox/training/config_base.py | 7 +++---- ultravox/training/helpers/prefetch_weights.py | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/ultravox/training/config_base.py b/ultravox/training/config_base.py index 3204e5a7..ec3ec58f 100644 --- a/ultravox/training/config_base.py +++ b/ultravox/training/config_base.py @@ -138,16 +138,15 @@ def fix_hyphens(arg: str): return re.sub(r"^--([^=]+)", lambda m: "--" + m.group(1).replace("-", "_"), arg) -def get_train_args(args: Optional[List[str]] = None) -> TrainConfig: +def get_train_args(override_sys_args: Optional[List[str]] = None) -> TrainConfig: """ Parse the command line arguments and return a TrainConfig object. Args: - args: The command line arguments. If None, sys.argv[1:] is used. + override_sys_args: The command line arguments. If None, sys.argv[1:] is used. This is mainly useful for testing. """ - if args is None: - args = sys.argv[1:] + args = override_sys_args or sys.argv[1:] return simple_parsing.parse( config_class=TrainConfig, diff --git a/ultravox/training/helpers/prefetch_weights.py b/ultravox/training/helpers/prefetch_weights.py index 3e5ad19e..83a6d660 100644 --- a/ultravox/training/helpers/prefetch_weights.py +++ b/ultravox/training/helpers/prefetch_weights.py @@ -9,11 +9,11 @@ ALLOW_PATTERNS = ["*.safetensors", "*.json"] -def main(args: Optional[List[str]] = None): +def main(override_sys_args: Optional[List[str]] = None): start = datetime.now() print("Downloading weights ...") - args = config_base.get_train_args(args) + args = config_base.get_train_args(override_sys_args) for model_id in [args.text_model, args.audio_model]: try: From 684b28138a318c82fb4979294e1dbda45367016a Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Wed, 11 Sep 2024 11:22:19 -0700 Subject: [PATCH 07/16] improved weight prefetching --- ultravox/training/helpers/prefetch_weights.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ultravox/training/helpers/prefetch_weights.py b/ultravox/training/helpers/prefetch_weights.py index 83a6d660..ef0068bc 100644 --- a/ultravox/training/helpers/prefetch_weights.py +++ b/ultravox/training/helpers/prefetch_weights.py @@ -2,6 +2,7 @@ from typing import List, Optional import huggingface_hub +import transformers from ultravox.model import wandb_utils from ultravox.training import config_base @@ -17,9 +18,13 @@ def main(override_sys_args: Optional[List[str]] = None): for model_id in [args.text_model, args.audio_model]: try: + # Download all model files that match ALLOW_PATTERNS + # This is faster than .from_pretrained due to parallel downloads huggingface_hub.snapshot_download( repo_id=model_id, allow_patterns=ALLOW_PATTERNS ) + # A backstop to make sure the model is fully downloaded even if ALLOW_PATTERNS is not enough + transformers.AutoModel.from_pretrained(model_id, device_map="meta") except huggingface_hub.utils.GatedRepoError as e: raise e except huggingface_hub.utils.RepositoryNotFoundError as e: From bb55ee46087c83692e79376052afb9a1dd9acb93 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Wed, 11 Sep 2024 11:23:44 -0700 Subject: [PATCH 08/16] comments --- ultravox/training/helpers/prefetch_weights.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ultravox/training/helpers/prefetch_weights.py b/ultravox/training/helpers/prefetch_weights.py index ef0068bc..6d43d935 100644 --- a/ultravox/training/helpers/prefetch_weights.py +++ b/ultravox/training/helpers/prefetch_weights.py @@ -24,6 +24,7 @@ def main(override_sys_args: Optional[List[str]] = None): repo_id=model_id, allow_patterns=ALLOW_PATTERNS ) # A backstop to make sure the model is fully downloaded even if ALLOW_PATTERNS is not enough + # Using `device_map="meta"` to avoid loading the weights into memory or device transformers.AutoModel.from_pretrained(model_id, device_map="meta") except huggingface_hub.utils.GatedRepoError as e: raise e From 7a818bc442a7b9d302b8d67e701df3d5b4acc87b Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Thu, 12 Sep 2024 11:00:22 -0700 Subject: [PATCH 09/16] moved from_pretrained call out of try/catch --- ultravox/training/helpers/prefetch_weights.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/ultravox/training/helpers/prefetch_weights.py b/ultravox/training/helpers/prefetch_weights.py index 6d43d935..a9284015 100644 --- a/ultravox/training/helpers/prefetch_weights.py +++ b/ultravox/training/helpers/prefetch_weights.py @@ -23,16 +23,21 @@ def main(override_sys_args: Optional[List[str]] = None): huggingface_hub.snapshot_download( repo_id=model_id, allow_patterns=ALLOW_PATTERNS ) - # A backstop to make sure the model is fully downloaded even if ALLOW_PATTERNS is not enough - # Using `device_map="meta"` to avoid loading the weights into memory or device - transformers.AutoModel.from_pretrained(model_id, device_map="meta") except huggingface_hub.utils.GatedRepoError as e: raise e except huggingface_hub.utils.RepositoryNotFoundError as e: + # We assume that the model is local if it's not found on HF Hub. + # The `.from_pretrained` call will verify the local case. print( f"Model {args.text_model} not found on HF Hub. Skipping download. Error: {e}" ) + # A backstop to make sure the model is fully downloaded. Scenarios to consider: + # - ALLOW_PATTERNS is not enough to download all files needed + # - The model is local + # Using `device_map="meta"` to avoid loading the weights into memory or device + transformers.AutoModel.from_pretrained(model_id, device_map="meta") + if args.model_load_dir and wandb_utils.is_wandb_url(args.model_load_dir): wandb_utils.download_model_from_wandb(args.model_load_dir) From e739633e774c74bcff797489caa0c09fcca0425f Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Thu, 12 Sep 2024 11:04:46 -0700 Subject: [PATCH 10/16] updated comments --- ultravox/training/helpers/prefetch_weights.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultravox/training/helpers/prefetch_weights.py b/ultravox/training/helpers/prefetch_weights.py index a9284015..8713ec64 100644 --- a/ultravox/training/helpers/prefetch_weights.py +++ b/ultravox/training/helpers/prefetch_weights.py @@ -34,7 +34,7 @@ def main(override_sys_args: Optional[List[str]] = None): # A backstop to make sure the model is fully downloaded. Scenarios to consider: # - ALLOW_PATTERNS is not enough to download all files needed - # - The model is local + # - The model is local, this will verify that everything is in order # Using `device_map="meta"` to avoid loading the weights into memory or device transformers.AutoModel.from_pretrained(model_id, device_map="meta") From f580653745df554b9a3d9e197f888de5c5df6b00 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Thu, 12 Sep 2024 12:29:49 -0700 Subject: [PATCH 11/16] mcli command -> multi-line --- mcloud.yaml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mcloud.yaml b/mcloud.yaml index 427ac9f2..1cbca1ce 100644 --- a/mcloud.yaml +++ b/mcloud.yaml @@ -10,7 +10,10 @@ integrations: git_branch: $UV_BRANCH pip_install: poetry==1.7.1 command: >- - cd ultravox && poetry install --no-dev && poetry run python -m ultravox.training.helpers.prefetch_weights $TRAIN_ARGS && poetry run torchrun --nproc_per_node=8 -m ultravox.training.train $TRAIN_ARGS + cd ultravox && + poetry install --no-dev && + poetry run python -m ultravox.training.helpers.prefetch_weights $TRAIN_ARGS && + poetry run torchrun --nproc_per_node=8 -m ultravox.training.train $TRAIN_ARGS env_variables: MLFLOW_TRACKING_URI: databricks UV_BRANCH: main From 794e423585e823a34ca0d590bc0fe202a11345e5 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Thu, 12 Sep 2024 14:03:46 -0700 Subject: [PATCH 12/16] raise error on ddp weights not downloaded --- ultravox/training/helpers/prefetch_weights.py | 11 +++++++++++ .../training/helpers/prefetch_weights_test.py | 6 ++++++ ultravox/training/train.py | 17 ++++++++++++++--- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/ultravox/training/helpers/prefetch_weights.py b/ultravox/training/helpers/prefetch_weights.py index 8713ec64..71b3b553 100644 --- a/ultravox/training/helpers/prefetch_weights.py +++ b/ultravox/training/helpers/prefetch_weights.py @@ -45,5 +45,16 @@ def main(override_sys_args: Optional[List[str]] = None): print(f"Weights are downloaded in {end - start} seconds") +def raise_on_weights_not_downloaded(model_ids: List[str]): + """ + This is an imperfect check to see if the model weights are downloaded, + but it can catch if prefetch_weights.py was not run. + """ + for model_id in model_ids: + huggingface_hub.snapshot_download( + repo_id=model_id, allow_patterns=ALLOW_PATTERNS, local_files_only=True + ) + + if __name__ == "__main__": main() diff --git a/ultravox/training/helpers/prefetch_weights_test.py b/ultravox/training/helpers/prefetch_weights_test.py index 19468c6b..2d4a8086 100644 --- a/ultravox/training/helpers/prefetch_weights_test.py +++ b/ultravox/training/helpers/prefetch_weights_test.py @@ -7,8 +7,14 @@ def test_prefetch_weights(): + # It would be nice to test this, but there isn't an easy way to clear the cache + # with pytest.raises(huggingface_hub.utils.LocalEntryNotFoundError): + # prefetch_weights.raise_on_weights_not_downloaded([TEXT_MODEL, AUDIO_MODEL]) + prefetch_weights.main(["--text-model", TEXT_MODEL, "--audio-model", AUDIO_MODEL]) + prefetch_weights.raise_on_weights_not_downloaded([TEXT_MODEL, AUDIO_MODEL]) + # With local_files_only=True, from_pretrained will throw an error if the weights are not downloaded transformers.AutoModel.from_pretrained(TEXT_MODEL, local_files_only=True) transformers.AutoModel.from_pretrained(AUDIO_MODEL, local_files_only=True) diff --git a/ultravox/training/train.py b/ultravox/training/train.py index e5f6d7be..f8ece8ee 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -26,6 +26,7 @@ from ultravox.model import ultravox_processing from ultravox.model import wandb_utils from ultravox.training import config_base +from ultravox.training.helpers import prefetch_weights INPUT_EXAMPLE = {"text": "Transcribe\n<|audio|>", "audio": b"\x00\x00" * 16000} OUTPUT_EXAMPLE = {"text": "Hello, world!"} @@ -87,8 +88,9 @@ def train(args: config_base.TrainConfig): world_size = int(os.environ.get("WORLD_SIZE", 1)) local_rank = int(os.environ.get("LOCAL_RANK", 0)) is_master = local_rank == 0 + is_distributed = world_size > 1 - if world_size > 1: + if is_distributed: torch.distributed.init_process_group(backend="nccl") # DDP blows up logging, so this is an attempt to suppress it to only logs from the master process @@ -116,8 +118,17 @@ def train(args: config_base.TrainConfig): logging.info("Instantiating model...") - # We assume that the weights are already downloaded via prefetch_weights.py - # If the weights are not downloaded, we might see a race condition here when using DDP. + if is_distributed: + try: + prefetch_weights.raise_on_weights_not_downloaded( + [args.text_model, args.audio_model] + ) + except Exception as e: + # We assume that the weights are already downloaded via prefetch_weights.py + # If the weights are not downloaded, we might see a race condition here when using DDP/FSDP. + logging.error("Weights are not downloaded. Please run prefetch_weights.py.") + raise e + model = ultravox_model.UltravoxModel(config) assert model.get_input_embeddings().num_embeddings == len( From db1988ee7f4a8c210e6eb7f7a864663f33a65dbc Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Thu, 12 Sep 2024 16:04:51 -0700 Subject: [PATCH 13/16] improved logs string Co-authored-by: Justin Uberti --- ultravox/training/helpers/prefetch_weights.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultravox/training/helpers/prefetch_weights.py b/ultravox/training/helpers/prefetch_weights.py index 71b3b553..7073d764 100644 --- a/ultravox/training/helpers/prefetch_weights.py +++ b/ultravox/training/helpers/prefetch_weights.py @@ -42,7 +42,7 @@ def main(override_sys_args: Optional[List[str]] = None): wandb_utils.download_model_from_wandb(args.model_load_dir) end = datetime.now() - print(f"Weights are downloaded in {end - start} seconds") + print(f"Weights downloaded in {end - start} seconds") def raise_on_weights_not_downloaded(model_ids: List[str]): From 97593c7b0463e3a3bc0a74227bc7dcc9450d8fc6 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Fri, 13 Sep 2024 09:28:13 -0700 Subject: [PATCH 14/16] use from_pretrained instead of snapshot for testing that model exists --- ultravox/training/helpers/prefetch_weights.py | 8 ++++---- ultravox/training/helpers/prefetch_weights_test.py | 6 ------ 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/ultravox/training/helpers/prefetch_weights.py b/ultravox/training/helpers/prefetch_weights.py index 7073d764..1d16812d 100644 --- a/ultravox/training/helpers/prefetch_weights.py +++ b/ultravox/training/helpers/prefetch_weights.py @@ -47,12 +47,12 @@ def main(override_sys_args: Optional[List[str]] = None): def raise_on_weights_not_downloaded(model_ids: List[str]): """ - This is an imperfect check to see if the model weights are downloaded, - but it can catch if prefetch_weights.py was not run. + This function checks to see if the model weights are downloaded and available locally. + If they are not, it raises an error. """ for model_id in model_ids: - huggingface_hub.snapshot_download( - repo_id=model_id, allow_patterns=ALLOW_PATTERNS, local_files_only=True + transformers.AutoModel.from_pretrained( + model_id, device_map="meta", local_files_only=True ) diff --git a/ultravox/training/helpers/prefetch_weights_test.py b/ultravox/training/helpers/prefetch_weights_test.py index 2d4a8086..3e6f13f8 100644 --- a/ultravox/training/helpers/prefetch_weights_test.py +++ b/ultravox/training/helpers/prefetch_weights_test.py @@ -1,5 +1,3 @@ -import transformers - from ultravox.training.helpers import prefetch_weights TEXT_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM" @@ -14,7 +12,3 @@ def test_prefetch_weights(): prefetch_weights.main(["--text-model", TEXT_MODEL, "--audio-model", AUDIO_MODEL]) prefetch_weights.raise_on_weights_not_downloaded([TEXT_MODEL, AUDIO_MODEL]) - - # With local_files_only=True, from_pretrained will throw an error if the weights are not downloaded - transformers.AutoModel.from_pretrained(TEXT_MODEL, local_files_only=True) - transformers.AutoModel.from_pretrained(AUDIO_MODEL, local_files_only=True) From 77d023d0be7312122c5811a8aea7f7e27ee78292 Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Fri, 13 Sep 2024 09:58:39 -0700 Subject: [PATCH 15/16] moving raise_on_weights_not_downloaded before tokenizer init --- ultravox/training/train.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/ultravox/training/train.py b/ultravox/training/train.py index f8ece8ee..8870120a 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -90,15 +90,26 @@ def train(args: config_base.TrainConfig): is_master = local_rank == 0 is_distributed = world_size > 1 - if is_distributed: - torch.distributed.init_process_group(backend="nccl") - # DDP blows up logging, so this is an attempt to suppress it to only logs from the master process logging.basicConfig(level=logging.INFO if is_master else logging.ERROR) # os.environ["TORCH_LOGS"] = "ERROR" if is_master else "WARNING" transformers.logging.set_verbosity(logging.WARNING if is_master else logging.ERROR) hf_datasets.logging.set_verbosity(logging.WARNING if is_master else logging.ERROR) + if is_distributed: + torch.distributed.init_process_group(backend="nccl") + + # make sure the weights are downloaded before initializing the model in distributed mode + try: + prefetch_weights.raise_on_weights_not_downloaded( + [args.text_model, args.audio_model] + ) + except Exception as e: + # We assume that the weights are already downloaded via prefetch_weights.py + # If the weights are not downloaded, we might see a race condition here when using DDP/FSDP. + logging.error("Weights are not downloaded. Please run prefetch_weights.py.") + raise e + logging.info("Instantiating processor...") text_tokenizer: transformers.PreTrainedTokenizerFast = ( transformers.AutoTokenizer.from_pretrained(args.text_model) @@ -117,18 +128,6 @@ def train(args: config_base.TrainConfig): ) logging.info("Instantiating model...") - - if is_distributed: - try: - prefetch_weights.raise_on_weights_not_downloaded( - [args.text_model, args.audio_model] - ) - except Exception as e: - # We assume that the weights are already downloaded via prefetch_weights.py - # If the weights are not downloaded, we might see a race condition here when using DDP/FSDP. - logging.error("Weights are not downloaded. Please run prefetch_weights.py.") - raise e - model = ultravox_model.UltravoxModel(config) assert model.get_input_embeddings().num_embeddings == len( From 3a79974fe0408994d23ec946e7ac46df73b77b5c Mon Sep 17 00:00:00 2001 From: Farzad Abdolhosseini Date: Fri, 13 Sep 2024 15:35:18 -0700 Subject: [PATCH 16/16] double check for prefetch_weights for local/test runs --- ultravox/training/helpers/prefetch_weights.py | 29 +++++++------------ .../training/helpers/prefetch_weights_test.py | 10 +++---- ultravox/training/train.py | 18 +++++------- 3 files changed, 24 insertions(+), 33 deletions(-) diff --git a/ultravox/training/helpers/prefetch_weights.py b/ultravox/training/helpers/prefetch_weights.py index 1d16812d..30449a38 100644 --- a/ultravox/training/helpers/prefetch_weights.py +++ b/ultravox/training/helpers/prefetch_weights.py @@ -16,7 +16,14 @@ def main(override_sys_args: Optional[List[str]] = None): args = config_base.get_train_args(override_sys_args) - for model_id in [args.text_model, args.audio_model]: + download_weights([args.text_model, args.audio_model], args.model_load_dir) + + end = datetime.now() + print(f"Weights downloaded in {end - start} seconds") + + +def download_weights(model_ids: List[str], model_load_dir: Optional[str] = None): + for model_id in model_ids: try: # Download all model files that match ALLOW_PATTERNS # This is faster than .from_pretrained due to parallel downloads @@ -29,7 +36,7 @@ def main(override_sys_args: Optional[List[str]] = None): # We assume that the model is local if it's not found on HF Hub. # The `.from_pretrained` call will verify the local case. print( - f"Model {args.text_model} not found on HF Hub. Skipping download. Error: {e}" + f"Model {model_id} not found on HF Hub. Skipping download. Error: {e}" ) # A backstop to make sure the model is fully downloaded. Scenarios to consider: @@ -38,22 +45,8 @@ def main(override_sys_args: Optional[List[str]] = None): # Using `device_map="meta"` to avoid loading the weights into memory or device transformers.AutoModel.from_pretrained(model_id, device_map="meta") - if args.model_load_dir and wandb_utils.is_wandb_url(args.model_load_dir): - wandb_utils.download_model_from_wandb(args.model_load_dir) - - end = datetime.now() - print(f"Weights downloaded in {end - start} seconds") - - -def raise_on_weights_not_downloaded(model_ids: List[str]): - """ - This function checks to see if the model weights are downloaded and available locally. - If they are not, it raises an error. - """ - for model_id in model_ids: - transformers.AutoModel.from_pretrained( - model_id, device_map="meta", local_files_only=True - ) + if model_load_dir and wandb_utils.is_wandb_url(model_load_dir): + wandb_utils.download_model_from_wandb(model_load_dir) if __name__ == "__main__": diff --git a/ultravox/training/helpers/prefetch_weights_test.py b/ultravox/training/helpers/prefetch_weights_test.py index 3e6f13f8..19468c6b 100644 --- a/ultravox/training/helpers/prefetch_weights_test.py +++ b/ultravox/training/helpers/prefetch_weights_test.py @@ -1,3 +1,5 @@ +import transformers + from ultravox.training.helpers import prefetch_weights TEXT_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM" @@ -5,10 +7,8 @@ def test_prefetch_weights(): - # It would be nice to test this, but there isn't an easy way to clear the cache - # with pytest.raises(huggingface_hub.utils.LocalEntryNotFoundError): - # prefetch_weights.raise_on_weights_not_downloaded([TEXT_MODEL, AUDIO_MODEL]) - prefetch_weights.main(["--text-model", TEXT_MODEL, "--audio-model", AUDIO_MODEL]) - prefetch_weights.raise_on_weights_not_downloaded([TEXT_MODEL, AUDIO_MODEL]) + # With local_files_only=True, from_pretrained will throw an error if the weights are not downloaded + transformers.AutoModel.from_pretrained(TEXT_MODEL, local_files_only=True) + transformers.AutoModel.from_pretrained(AUDIO_MODEL, local_files_only=True) diff --git a/ultravox/training/train.py b/ultravox/training/train.py index 8870120a..d74620e2 100644 --- a/ultravox/training/train.py +++ b/ultravox/training/train.py @@ -26,6 +26,7 @@ from ultravox.model import ultravox_processing from ultravox.model import wandb_utils from ultravox.training import config_base +from ultravox.training import ddp_utils from ultravox.training.helpers import prefetch_weights INPUT_EXAMPLE = {"text": "Transcribe\n<|audio|>", "audio": b"\x00\x00" * 16000} @@ -99,16 +100,13 @@ def train(args: config_base.TrainConfig): if is_distributed: torch.distributed.init_process_group(backend="nccl") - # make sure the weights are downloaded before initializing the model in distributed mode - try: - prefetch_weights.raise_on_weights_not_downloaded( - [args.text_model, args.audio_model] - ) - except Exception as e: - # We assume that the weights are already downloaded via prefetch_weights.py - # If the weights are not downloaded, we might see a race condition here when using DDP/FSDP. - logging.error("Weights are not downloaded. Please run prefetch_weights.py.") - raise e + with ddp_utils.run_on_master_first(is_master): + # For larger models, we assume that the weights are already downloaded via prefetch_weights.py + # Otherwise the barrier call can timeout. + # This call is only here as a backstop in case prefetch_weights.py was not run, for example in a local/test run. + prefetch_weights.download_weights( + [args.text_model, args.audio_model], args.model_load_dir + ) logging.info("Instantiating processor...") text_tokenizer: transformers.PreTrainedTokenizerFast = (