Commit
Update transformer_engine._convert_model to skip LoRA layers (#1766)
vivekgoe authored Feb 18, 2025
1 parent d80283e commit 21a5495
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion optimum/habana/accelerate/utils/transformer_engine.py
@@ -16,6 +16,13 @@
 import functools
 
 import torch
+from transformers.utils import (
+    is_peft_available,
+)
+
+
+if is_peft_available():
+    from peft.tuners import lora
 
 
 has_transformer_engine = False
@@ -47,7 +54,30 @@ def _convert_model(model, to_transformer_engine=True, _convert_linear=True, _min
     if not is_fp8_available():
         raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
     for name, module in model.named_children():
-        if isinstance(module, torch.nn.Linear) and to_transformer_engine and _convert_linear:
+        if is_peft_available() and isinstance(module, lora.Linear) and to_transformer_engine and _convert_linear:
+            # For lora linear module, convert only base linear layer to fp8 and skip lora-a,
+            # lora-b linear layers. Since lora-a, lora-b are small in size, there is not much
+            # device performance gain by pushing these in fp8. This way we avoid host overhead
+            # associated with using TE for these layers.
+            for name, lora_module in module.named_children():
+                if name == "base_layer":
+                    has_bias = lora_module.bias is not None
+                    # Initializing TE linear without weights and biases and shallow copying them from the original module.
+                    te_module = te.Linear(
+                        lora_module.in_features,
+                        lora_module.out_features,
+                        bias=has_bias,
+                        params_dtype=lora_module.weight.dtype,
+                        skip_weight_param_allocation=True,
+                        minimize_memory=_minimize_memory,
+                    )
+                    te_module.weight = lora_module.weight
+
+                    if has_bias:
+                        te_module.bias = lora_module.bias
+
+                    setattr(module, name, te_module)
+        elif isinstance(module, torch.nn.Linear) and to_transformer_engine and _convert_linear:
             has_bias = module.bias is not None
             # Initializing TE linear without weights and biases and shallow copying them from the original module.
             te_module = te.Linear(
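
The traversal above depends on the module layout that PEFT gives a LoRA-wrapped linear layer: a lora.Linear container whose children include the original base_layer together with the small lora_A / lora_B projections. Below is a minimal sketch of that layout, assuming the peft package is installed; the toy model, the "proj" target name, and the base_model.model.proj attribute path are illustrative assumptions, not taken from the repository.

import torch
from peft import LoraConfig, get_peft_model


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(64, 64, bias=True)

    def forward(self, x):
        return self.proj(x)


# Wrap the (illustrative) "proj" layer with a LoRA adapter.
model = get_peft_model(TinyModel(), LoraConfig(r=8, target_modules=["proj"]))

# The wrapped layer is a peft lora.Linear; its named_children() include
# "base_layer" (the original torch.nn.Linear) next to the small lora_A /
# lora_B projections. The patch above swaps only "base_layer" for a
# te.Linear and leaves the LoRA projections as regular torch.nn.Linear.
lora_linear = model.base_model.model.proj
for child_name, child in lora_linear.named_children():
    print(child_name, type(child).__name__)

Skipping the lora_A / lora_B children keeps the host-side conversion overhead low, while the large base projection still runs in fp8 through Transformer Engine, which is the rationale stated in the comment of the patch.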
