diff --git a/ultravox/model/ultravox_model.py b/ultravox/model/ultravox_model.py
index cb209888..89b01aca 100644
--- a/ultravox/model/ultravox_model.py
+++ b/ultravox/model/ultravox_model.py
@@ -34,7 +34,6 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
     config_class = UltravoxConfig
    config: UltravoxConfig  # for type hinting
 
-    _no_split_modules = ["Wav2Vec2Model", "WhisperEncoder", "LlamaDecoderLayer"]
     # We minimize the weights in state_dict in order to reduce the size of the checkpoint
     # The issue is that load_pretrained() uses state_dict() keys to know what keys are expected
     # As such we have to tell is to ignore some keys that are not always in the model
@@ -46,6 +45,7 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
 
     def __init__(self, config: UltravoxConfig):
         super().__init__(config)
+        self._register_load_state_dict_pre_hook(self._pre_load_state_dict_hook)
 
         self.keep_params: Set[str] = set()
         self.vocab_size = config.vocab_size
@@ -54,6 +54,13 @@ def __init__(self, config: UltravoxConfig):
         self.multi_modal_projector = UltravoxProjector(config)
         self.language_model = self._create_language_model(config)
 
+        # Determine no_split_modules dynamically to use with FSDP auto_wrap policy.
+        # FSDP throws an error if some of the layer types are not found in the model.
+        # This would be something like ["LlamaDecoderLayer", "WhisperEncoderLayer"]
+        self._no_split_modules = (self.language_model._no_split_modules or []) + (
+            self.audio_tower._no_split_modules or []
+        )
+
         self.loss_config = LossConfig()
         self.post_init()
 
@@ -356,9 +363,13 @@ def push_to_hub(self, *args, **kwargs):
         self.to(self.language_model.dtype)
         return super().push_to_hub(*args, **kwargs)
 
-    def state_dict(self, *args, **kwargs):
+    def save_pretrained(
+        self, *args, state_dict: Optional[Dict[str, Any]] = None, **kwargs
+    ):
+        if state_dict is None:
+            state_dict = super().state_dict()
+
         named_params = dict(self.named_parameters())
-        state_dict = super().state_dict(*args, **kwargs)
 
         state_dict = {
             k: v
@@ -366,16 +377,11 @@ def state_dict(self, *args, **kwargs):
             if k in self.keep_params
             or (k in named_params and named_params[k].requires_grad)
         }
-        return state_dict
 
-    def load_state_dict(
-        self,
-        state_dict: Dict[str, Any],
-        *args,
-        **kwargs,
-    ):
+        super().save_pretrained(*args, state_dict=state_dict, **kwargs)
+
+    def _pre_load_state_dict_hook(self, state_dict: Dict[str, Any], *args, **kwargs):
         self.keep_params.update(set(state_dict.keys()))
-        return super().load_state_dict(state_dict, *args, **kwargs)
 
     def print_trainable_parameters(self):
         """
@@ -510,6 +516,7 @@ class ModifiedWhisperEncoder(whisper.WhisperEncoder):
     """
 
     base_model_prefix = "model.encoder"
+    _no_split_modules = ["WhisperEncoderLayer"]
 
     def forward(
         self,
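For context, here is a minimal sketch of how a dynamic `_no_split_modules` list like the one assembled above is typically turned into an FSDP transformer auto-wrap policy. The helper name `build_auto_wrap_policy` and the name-to-class resolution loop are illustrative assumptions, not the exact code path that HF Trainer/accelerate runs:

```python
import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def build_auto_wrap_policy(model):
    """Resolve the names in model._no_split_modules to classes and build a wrap policy."""
    layer_classes = set()
    for name in model._no_split_modules:  # e.g. ["LlamaDecoderLayer", "WhisperEncoderLayer"]
        for module in model.modules():
            if module.__class__.__name__ == name:
                layer_classes.add(module.__class__)
                break
        else:
            # Mirrors the failure mode noted in the comment above: wrapping fails
            # if a listed layer type is not actually present in the model.
            raise ValueError(f"Could not find layer class {name} in the model")
    return functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls=layer_classes
    )


# Usage (illustrative): FSDP(model, auto_wrap_policy=build_auto_wrap_policy(model))
```

Building the list from `language_model._no_split_modules` plus `audio_tower._no_split_modules` keeps it in sync with whichever backbone and audio encoder the config selects, instead of hard-coding layer class names at the class level.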
diff --git a/ultravox/training/train.py b/ultravox/training/train.py
index 05cea992..c4a4b43a 100644
--- a/ultravox/training/train.py
+++ b/ultravox/training/train.py
@@ -317,12 +317,22 @@ def train(args: config_base.TrainConfig):
     logging.info(f"train end time: {t_end}")
     logging.info(f"elapsed: {t_end - t_start}")
 
-    if is_master:
-        # Saving the model using pipeline to ensure its code is saved
-        pipeline = ultravox_pipeline.UltravoxPipeline(
-            model, tokenizer=text_tokenizer, device=device
-        )
-        pipeline.save_pretrained(args.output_dir)
+    # We use both pipeline.save_pretrained and trainer.save_model to save everything.
+    # This is because pipeline.save_pretrained knows how to save the pipeline (code and config),
+    # but it doesn't know how to save FSDP models correctly (the final tensors could be flattened).
+    # On the other hand, trainer.save_model knows how to save FSDP models correctly, but it won't save the pipeline.
+    # Saving FSDP models is already quite slow though, so we don't want to save the model twice.
+    pipeline = ultravox_pipeline.UltravoxPipeline(
+        model, tokenizer=text_tokenizer, device=model.device
+    )
+    old_save_pretrained = model.save_pretrained
+    model.save_pretrained = lambda *_, **__: None  # type: ignore[method-assign]
+    # saves the pipeline code and populates the config
+    pipeline.save_pretrained(args.output_dir)
+    model.save_pretrained = old_save_pretrained  # type: ignore[method-assign]
+
+    # saves the model weights correctly (FSDP or otherwise)
+    trainer.save_model(args.output_dir)
 
 
 def evaluate(args: config_base.TrainConfig):
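For context, a minimal sketch of the full-state-dict gathering that FSDP-aware saving (of the kind `trainer.save_model` performs) relies on, and which a plain `model.state_dict()` call on a wrapped model does not guarantee. The helper name `gather_full_state_dict` is illustrative, and the exact mechanism inside Trainer/accelerate may differ:

```python
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)


def gather_full_state_dict(fsdp_model):
    """Materialize full, unflattened parameters on rank 0, offloaded to CPU."""
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
        state_dict = fsdp_model.state_dict()
    return state_dict if dist.get_rank() == 0 else None
```

Temporarily replacing `model.save_pretrained` with a no-op lets `pipeline.save_pretrained` write the pipeline code and config without also dumping weights that might still be sharded, after which `trainer.save_model` writes the weights exactly once.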