diff --git a/optimum/habana/accelerate/utils/transformer_engine.py b/optimum/habana/accelerate/utils/transformer_engine.py
index df0a07b5c4..89e1a895aa 100755
--- a/optimum/habana/accelerate/utils/transformer_engine.py
+++ b/optimum/habana/accelerate/utils/transformer_engine.py
@@ -16,6 +16,13 @@
 import functools
 
 import torch
+from transformers.utils import (
+    is_peft_available,
+)
+
+
+if is_peft_available():
+    from peft.tuners import lora
 
 
 has_transformer_engine = False
@@ -47,7 +54,30 @@ def _convert_model(model, to_transformer_engine=True, _convert_linear=True, _min
     if not is_fp8_available():
         raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
     for name, module in model.named_children():
-        if isinstance(module, torch.nn.Linear) and to_transformer_engine and _convert_linear:
+        if is_peft_available() and isinstance(module, lora.Linear) and to_transformer_engine and _convert_linear:
+            # For a LoRA linear module, convert only the base linear layer to FP8 and skip the
+            # lora_A and lora_B linear layers. Since lora_A and lora_B are small, there is little
+            # device-side performance gain from running them in FP8, and skipping them avoids the
+            # host overhead associated with using TE for these layers.
+            for name, lora_module in module.named_children():
+                if name == "base_layer":
+                    has_bias = lora_module.bias is not None
+                    # Initializing TE linear without weights and biases and shallow copying them from the original module.
+                    te_module = te.Linear(
+                        lora_module.in_features,
+                        lora_module.out_features,
+                        bias=has_bias,
+                        params_dtype=lora_module.weight.dtype,
+                        skip_weight_param_allocation=True,
+                        minimize_memory=_minimize_memory,
+                    )
+                    te_module.weight = lora_module.weight
+
+                    if has_bias:
+                        te_module.bias = lora_module.bias
+
+                    setattr(module, name, te_module)
+        elif isinstance(module, torch.nn.Linear) and to_transformer_engine and _convert_linear:
             has_bias = module.bias is not None
             # Initializing TE linear without weights and biases and shallow copying them from the original module.
             te_module = te.Linear(
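
For context, a minimal runnable sketch (not part of the patch) of the module layout the new branch walks. It builds a toy PEFT model (ToyModel and the "proj" layer name are hypothetical, chosen only for illustration) and lists the children of the resulting peft lora.Linear wrapper. Only the "base_layer" child matches the condition added above; the lora_A and lora_B adapters (ModuleDicts holding the small low-rank nn.Linear layers) are never touched by the conversion loop.

from torch import nn
from peft import LoraConfig, get_peft_model


class ToyModel(nn.Module):
    # Hypothetical model used only to demonstrate the PEFT wrapping; any module
    # whose target layers are nn.Linear would behave the same way.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)


model = get_peft_model(ToyModel(), LoraConfig(target_modules=["proj"], r=4))

# model.base_model.model.proj is now a peft lora.Linear. Iterating its children
# mirrors the loop in the patch: only "base_layer" would be replaced by te.Linear.
lora_linear = model.base_model.model.proj
for name, child in lora_linear.named_children():
    print(name, "->", type(child).__name__)
# Prints (among others): base_layer -> Linear, lora_A -> ModuleDict, lora_B -> ModuleDict

Note also that te_module.weight = lora_module.weight assigns the existing parameter rather than copying it; together with skip_weight_param_allocation=True, this is the shallow copy the patch comments refer to, so the conversion allocates no new weight storage.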