
Commit

Update
Zhongqiang Huang committed Sep 17, 2024
2 parents fcb81ff + be8ee6b commit 0c80894
Showing 5 changed files with 103 additions and 42 deletions.
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -29,7 +29,7 @@ tensorboardx = "~2.6.2.2"
wandb = "~0.17.1"
sacrebleu = "^2.4.2"
tenacity = "^9.0.0"
evals = {git = "https://github.com/fixie-ai/evals"}
evals = {git = "https://github.com/fixie-ai/evals", rev = "0c66bf85df7a4b903ecb202b23c2a826b749fd71"}

[tool.poetry.group.dev.dependencies]
black = "~24.4.2"
53 changes: 36 additions & 17 deletions ultravox/model/ultravox_model.py
@@ -34,7 +34,6 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):

config_class = UltravoxConfig
config: UltravoxConfig # for type hinting
_no_split_modules = ["Wav2Vec2Model", "WhisperEncoder", "LlamaDecoderLayer"]
# We minimize the weights in state_dict in order to reduce the size of the checkpoint
# The issue is that load_pretrained() uses state_dict() keys to know what keys are expected
    # As such we have to tell it to ignore some keys that are not always in the model
@@ -46,14 +45,22 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):

def __init__(self, config: UltravoxConfig):
super().__init__(config)
self._register_load_state_dict_pre_hook(self._pre_load_state_dict_hook)

self.keep_params: Set[str] = set()
self.vocab_size = config.vocab_size

self.audio_tower = self._create_audio_tower(config)
self.multi_modal_projector = UltravoxProjector(config)
self.multi_modal_projector = self._create_multi_modal_projector(config)
self.language_model = self._create_language_model(config)

# Determine no_split_modules dynamically to use with FSDP auto_wrap policy.
# FSDP throws an error if some of the layer types are not found in the model.
# This would be something like ["LlamaDecoderLayer", "WhisperEncoderLayer"]
self._no_split_modules = (self.language_model._no_split_modules or []) + (
self.audio_tower._no_split_modules or []
)

self.loss_config = LossConfig()
self.post_init()

@@ -188,7 +195,7 @@ def forward(

# B x A/3200 x D
audio_tower_output = self.audio_tower.forward(
audio_values
audio_values.to(self.audio_tower.dtype)
).last_hidden_state
audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)

@@ -265,18 +272,26 @@ def prepare_inputs_for_generation(

return model_input

@classmethod
def _create_multi_modal_projector(
cls, config: UltravoxConfig
) -> "UltravoxProjector":
projector = UltravoxProjector(config)
projector.to(config.torch_dtype)
return projector

@classmethod
def _create_audio_tower(
cls, config: UltravoxConfig
) -> Union[transformers.Wav2Vec2Model, "ModifiedWhisperEncoder"]:
if config.audio_model_id is not None:
if "whisper" in config.audio_model_id is not None:
audio_tower = ModifiedWhisperEncoder.from_pretrained(
config.audio_model_id
config.audio_model_id, torch_dtype=config.torch_dtype
)
else:
audio_tower = transformers.AutoModel.from_pretrained(
config.audio_model_id
config.audio_model_id, torch_dtype=config.torch_dtype
)
else:
if "whisper" in config.audio_config._name_or_path:
@@ -307,14 +322,18 @@ def _create_language_model(
) -> transformers.LlamaForCausalLM:
if config.text_model_id is not None:
language_model = transformers.AutoModelForCausalLM.from_pretrained(
config.text_model_id, attn_implementation=config._attn_implementation
config.text_model_id,
attn_implementation=config._attn_implementation,
torch_dtype=config.torch_dtype,
)
else:
with transformers.modeling_utils.no_init_weights():
# we only ever use from_config if the weights are retrained, hence initializing is not
                # required. This makes model creation quite a bit faster since init on CPU is quite slow.
language_model = transformers.AutoModelForCausalLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
config.text_config,
attn_implementation=config._attn_implementation,
torch_dtype=config.torch_dtype,
)

language_model = apply_lora(language_model, config.text_model_lora_config)
@@ -356,26 +375,25 @@ def push_to_hub(self, *args, **kwargs):
self.to(self.language_model.dtype)
return super().push_to_hub(*args, **kwargs)

def state_dict(self, *args, **kwargs):
def save_pretrained(
self, *args, state_dict: Optional[Dict[str, Any]] = None, **kwargs
):
if state_dict is None:
state_dict = super().state_dict()

named_params = dict(self.named_parameters())
state_dict = super().state_dict(*args, **kwargs)

state_dict = {
k: v
for k, v in state_dict.items()
if k in self.keep_params
or (k in named_params and named_params[k].requires_grad)
}
return state_dict

def load_state_dict(
self,
state_dict: Dict[str, Any],
*args,
**kwargs,
):
super().save_pretrained(*args, state_dict=state_dict, **kwargs)

def _pre_load_state_dict_hook(self, state_dict: Dict[str, Any], *args, **kwargs):
self.keep_params.update(set(state_dict.keys()))
return super().load_state_dict(state_dict, *args, **kwargs)

def print_trainable_parameters(self):
"""
@@ -510,6 +528,7 @@ class ModifiedWhisperEncoder(whisper.WhisperEncoder):
"""

base_model_prefix = "model.encoder"
_no_split_modules = ["WhisperEncoderLayer"]

def forward(
self,
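
Note on the _no_split_modules change above: the list is now assembled in __init__ from the chosen language model and audio tower, so FSDP's auto-wrap policy only ever sees layer classes that actually exist in the instantiated model (per the comment, FSDP throws an error otherwise). Below is a minimal sketch of how such a list typically feeds an auto-wrap policy, using PyTorch's FSDP wrap utilities; this is an illustration, not the Trainer's internal code, and `model` is assumed to be an instantiated UltravoxModel.

import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def build_auto_wrap_policy(model: nn.Module):
    # Names collected dynamically, e.g. {"LlamaDecoderLayer", "WhisperEncoderLayer"}.
    wanted = set(model._no_split_modules)
    # Resolve the names to classes actually present in this model; FSDP raises if it is
    # asked to wrap a layer type that never occurs, hence the dynamic list.
    layer_classes = {type(m) for m in model.modules() if type(m).__name__ in wanted}
    return functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls=layer_classes
    )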
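
The save_pretrained override and _pre_load_state_dict_hook above implement the checkpoint-minimization described in the class comment: only parameters that are trainable, or that were present in the checkpoint the model was loaded from, are written out. A toy, self-contained sketch of that filter (made-up module, not the real Ultravox classes):

import torch


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.frozen = torch.nn.Linear(4, 4)   # stands in for a frozen backbone
        self.frozen.requires_grad_(False)
        self.adapter = torch.nn.Linear(4, 4)  # stands in for trainable projector/LoRA weights


toy = Toy()
keep_params = set()  # in the real model this is filled by the pre-load state_dict hook
named = dict(toy.named_parameters())
minimal_sd = {
    k: v
    for k, v in toy.state_dict().items()
    if k in keep_params or (k in named and named[k].requires_grad)
}
print(sorted(minimal_sd))  # only the adapter weight/bias survive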
15 changes: 15 additions & 0 deletions ultravox/training/config_base.py
@@ -93,6 +93,9 @@ class TrainConfig:
data_type: str = device_helpers.default_dtype_str()
"""Data type to use for training (e.g., 'bfloat16', 'float16', 'float32')."""

use_fsdp: bool = False
"""Whether to use FSDP for distributed training."""

model_load_dir: Optional[str] = None
"""
Path to load pretrained ultravox model from. Can be local path, HF hub model_id, or W&B artifact.
@@ -217,6 +220,18 @@ def __post_init__(self):
)
self.disable_layerdrop = True

if self.use_fsdp and self.save_steps:
logging.warning(
"FSDP is enabled: Saving checkpoints is going to be extremely slow and results in a full save."
" Consider setting save_steps=0."
)

if self.use_fsdp and self.do_eval:
logging.warning(
"FSDP is enabled: Evaluation is not supported with FSDP. Disabling evaluation."
)
self.do_eval = False


def get_train_args(override_sys_args: Optional[List[str]] = None) -> TrainConfig:
"""
69 changes: 48 additions & 21 deletions ultravox/training/train.py
@@ -26,6 +26,7 @@
from ultravox.training import config_base
from ultravox.training import ddp_utils
from ultravox.training.helpers import prefetch_weights
from ultravox.utils import device_helpers

INPUT_EXAMPLE = {"text": "Transcribe\n<|audio|>", "audio": b"\x00\x00" * 16000}
OUTPUT_EXAMPLE = {"text": "Hello, world!"}
@@ -114,6 +115,8 @@ def train(args: config_base.TrainConfig):
text_model_id=args.text_model,
text_model_lora_config=args.text_model_lora_config,
audio_model_lora_config=args.audio_model_lora_config,
torch_dtype=args.data_type,
pad_token_id=text_tokenizer.eos_token_id,
)

logging.info("Instantiating model...")
@@ -175,13 +178,10 @@ def train(args: config_base.TrainConfig):

model.print_trainable_parameters()

# Move the model to GPU and enable bfloat16
dtype = getattr(torch, args.data_type)
device = torch.device(args.device, index=local_rank)
logging.info(
f"Using dtype and device (world_size): {dtype}, {device} ({world_size})"
)
model.to(device=device, dtype=dtype)
if not args.use_fsdp:
# Moving to device in FSDP is handled by the Trainer
model.to(device=torch.device(args.device, index=local_rank))
logging.info(f"Using device (world_size): {model.device} ({world_size})")

# Prepare dataset, subsetting if needed
train_dataset: data.IterableDataset
@@ -242,9 +242,9 @@
optim=args.optimizer,
num_train_epochs=args.num_epochs,
max_steps=args.max_steps,
evaluation_strategy="steps",
eval_strategy="steps" if args.val_steps else "no",
eval_steps=args.val_steps,
save_strategy="steps",
save_strategy="steps" if args.save_steps else "no",
save_steps=args.save_steps,
logging_first_step=True,
logging_dir=args.logs_dir,
@@ -262,14 +262,19 @@
lr_scheduler_type=args.lr_scheduler,
warmup_steps=args.lr_warmup_steps,
weight_decay=args.weight_decay,
fp16=dtype == torch.float16,
bf16=dtype == torch.bfloat16,
# fp16=dtype == torch.float16,
# bf16=dtype == torch.bfloat16,
use_cpu=args.device == "cpu",
seed=args.seed + local_rank,
report_to=args.report_logs_to,
# torch_compile=True,
# fsdp="full_shard auto_wrap",
# fsdp_transformer_layer_cls_to_wrap='LlamaDecoderLayer',
fsdp="full_shard auto_wrap" if args.use_fsdp else "",
fsdp_config={
"backward_prefetch": "backward_pre",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"state_dict_type": "SHARDED_STATE_DICT",
"sync_module_states": "true",
},
),
)

@@ -278,8 +283,15 @@ def train(args: config_base.TrainConfig):
logging.info("Starting training...")
t_start = datetime.now()
logging.info(f"train start time: {t_start}")

if args.val_steps:
trainer.evaluate()
if args.use_fsdp:
logging.warning(
"FSDP is enabled: Skipping initial validation since model is not initialized."
)
else:
trainer.evaluate()

trainer.train()
t_end = datetime.now()
logging.info(f"train end time: {t_end}")
@@ -300,7 +312,7 @@ def train(args: config_base.TrainConfig):
processor=processor,
tokenizer=text_tokenizer,
device=args.device,
dtype=dtype,
dtype=device_helpers.get_dtype(args.data_type),
)

metrics, output_files = eval.run_infer(
@@ -319,12 +331,27 @@ def train(args: config_base.TrainConfig):
logging.info(f"eval end time: {t_end}")
logging.info(f"elapsed: {t_end - t_start}")

if is_master:
# Saving the model using pipeline to ensure its code is saved
pipeline = ultravox_pipeline.UltravoxPipeline(
model, tokenizer=text_tokenizer, device=device
)
pipeline.save_pretrained(args.output_dir)
if args.use_fsdp:
# For training checkpoints, we want to use SHARDED_STATE_DICT which should be faster,
# but for the final save we want FULL_STATE_DICT so it can be serialized properly.
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

# We use both pipeline.save_pretrained and trainer.save_model to save everything.
# This is because pipeline.save_pretrained knows how to save the pipeline (code and config),
# but it doesn't know how to save FSDP models correctly (the final tensors could be flattened).
# on the other hand, trainer.save_model knows how to save FSDP models correctly, but it won't save the pipeline.
# Saving FSDP models is already quite slow though, so we don't want to save the model twice.
pipeline = ultravox_pipeline.UltravoxPipeline(
model, tokenizer=text_tokenizer, device=model.device
)
old_save_pretrained = model.save_pretrained
model.save_pretrained = lambda *_, **__: None # type: ignore[method-assign]
# saves the pipeline code and populates the config
pipeline.save_pretrained(args.output_dir)
model.save_pretrained = old_save_pretrained # type: ignore[method-assign]

# saves the model weights correctly (FSDP or otherwise)
trainer.save_model(args.output_dir)


if __name__ == "__main__":
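
The final-evaluation path now obtains the dtype via device_helpers.get_dtype(args.data_type) instead of the removed getattr(torch, args.data_type) expression. The actual helper in ultravox.utils.device_helpers is not shown in this diff; a minimal sketch of what such a mapping presumably looks like, based on the old getattr call:

import torch


def get_dtype(data_type: str) -> torch.dtype:
    # Map a config string such as "bfloat16", "float16" or "float32" to a torch.dtype.
    dtype = getattr(torch, data_type, None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Unsupported data_type: {data_type!r}")
    return dtype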
