diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py
index 08fb914ccc..8efde30a57 100644
--- a/optimum/habana/transformers/trainer.py
+++ b/optimum/habana/transformers/trainer.py
@@ -986,7 +986,13 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optio
                         inputs["flash_attention_causal_mask"] = True
                 if self.model.config is not None:
                     if self.model.config.model_type in ["llama", "qwen2", "mistral", "starcoder2"]:
-                        inputs["lazy_mode"] = args.use_lazy_mode
+                        if _is_peft_model(model):
+                            forward_method = getattr(model.get_base_model(), "forward")
+                        else:
+                            forward_method = getattr(model, "forward")
+                        signature = inspect.signature(forward_method)
+                        if "lazy_mode" in signature.parameters:
+                            inputs["lazy_mode"] = args.use_lazy_mode
                 # TODO: keep syncs for fast DDP?
                 with self.accelerator.accumulate(model):
                     tr_loss_step = self.training_step(model, inputs)
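
For context, a minimal standalone sketch of the pattern this hunk applies: inspect the (unwrapped) model's `forward` signature with `inspect.signature` and only inject a keyword argument such as `lazy_mode` when `forward` actually accepts it, unwrapping PEFT wrappers via `get_base_model()` first. The `TinyModel` class and `maybe_add_kwarg` helper are illustrative names only, not part of the trainer.

```python
import inspect

import torch


class TinyModel(torch.nn.Module):
    """Illustrative model whose forward does NOT accept `lazy_mode`."""

    def forward(self, input_ids):
        return input_ids


def maybe_add_kwarg(model, inputs, name, value, is_peft=False):
    """Add `name` to `inputs` only if the (unwrapped) model's forward accepts it.

    Mirrors the check in the diff above: PEFT wrappers are unwrapped with
    `get_base_model()` so the signature of the real forward is inspected.
    """
    forward_method = model.get_base_model().forward if is_peft else model.forward
    if name in inspect.signature(forward_method).parameters:
        inputs[name] = value
    return inputs


if __name__ == "__main__":
    inputs = {"input_ids": torch.ones(1, 4, dtype=torch.long)}
    inputs = maybe_add_kwarg(TinyModel(), inputs, "lazy_mode", True)
    # `lazy_mode` is not in TinyModel.forward's signature, so it is not added.
    assert "lazy_mode" not in inputs
```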