Skip to content

Commit

Permalink
[Peft] fix saving / loading when unet is not "unet" (#6046)
Browse files Browse the repository at this point in the history
* [Peft] fix saving / loading when unet is not "unet"

* Update src/diffusers/loaders/lora.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* undo stablediffusion-xl changes

* use unet_name to get unet for lora helpers

* use unet_name

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
  • Loading branch information
kashif and sayakpaul committed Dec 26, 2023
1 parent 404351f commit 4c7e983
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 20 deletions.
6 changes: 4 additions & 2 deletions src/diffusers/loaders/ip_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,11 @@ def load_ip_adapter(
self.feature_extractor = CLIPImageProcessor()

# load ip-adapter into unet
self.unet._load_ip_adapter_weights(state_dict)
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
unet._load_ip_adapter_weights(state_dict)

def set_ip_adapter_scale(self, scale):
for attn_processor in self.unet.attn_processors.values():
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
for attn_processor in unet.attn_processors.values():
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
attn_processor.scale = scale
46 changes: 28 additions & 18 deletions src/diffusers/loaders/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,10 +912,10 @@ def pack_weights(layers, prefix):
)

if unet_lora_layers:
state_dict.update(pack_weights(unet_lora_layers, "unet"))
state_dict.update(pack_weights(unet_lora_layers, cls.unet_name))

if text_encoder_lora_layers:
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
state_dict.update(pack_weights(text_encoder_lora_layers, cls.text_encoder_name))

if transformer_lora_layers:
state_dict.update(pack_weights(transformer_lora_layers, "transformer"))
Expand Down Expand Up @@ -975,20 +975,22 @@ def unload_lora_weights(self):
>>> ...
```
"""
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet

if not USE_PEFT_BACKEND:
if version.parse(__version__) > version.parse("0.23"):
logger.warn(
"You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
"you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
)

for _, module in self.unet.named_modules():
for _, module in unet.named_modules():
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)
else:
recurse_remove_peft_layers(self.unet)
if hasattr(self.unet, "peft_config"):
del self.unet.peft_config
recurse_remove_peft_layers(unet)
if hasattr(unet, "peft_config"):
del unet.peft_config

# Safe to call the following regardless of LoRA.
self._remove_text_encoder_monkey_patch()
Expand Down Expand Up @@ -1027,7 +1029,8 @@ def fuse_lora(
)

if fuse_unet:
self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)

if USE_PEFT_BACKEND:
from peft.tuners.tuners_utils import BaseTunerLayer
Expand Down Expand Up @@ -1080,13 +1083,14 @@ def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
if unfuse_unet:
if not USE_PEFT_BACKEND:
self.unet.unfuse_lora()
unet.unfuse_lora()
else:
from peft.tuners.tuners_utils import BaseTunerLayer

for module in self.unet.modules():
for module in unet.modules():
if isinstance(module, BaseTunerLayer):
module.unmerge()

Expand Down Expand Up @@ -1202,8 +1206,9 @@ def set_adapters(
adapter_names: Union[List[str], str],
adapter_weights: Optional[List[float]] = None,
):
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
# Handle the UNET
self.unet.set_adapters(adapter_names, adapter_weights)
unet.set_adapters(adapter_names, adapter_weights)

# Handle the Text Encoder
if hasattr(self, "text_encoder"):
Expand All @@ -1216,7 +1221,8 @@ def disable_lora(self):
raise ValueError("PEFT backend is required for this method.")

# Disable unet adapters
self.unet.disable_lora()
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
unet.disable_lora()

# Disable text encoder adapters
if hasattr(self, "text_encoder"):
Expand All @@ -1229,7 +1235,8 @@ def enable_lora(self):
raise ValueError("PEFT backend is required for this method.")

# Enable unet adapters
self.unet.enable_lora()
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
unet.enable_lora()

# Enable text encoder adapters
if hasattr(self, "text_encoder"):
Expand All @@ -1251,7 +1258,8 @@ def delete_adapters(self, adapter_names: Union[List[str], str]):
adapter_names = [adapter_names]

# Delete unet adapters
self.unet.delete_adapters(adapter_names)
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
unet.delete_adapters(adapter_names)

for adapter_name in adapter_names:
# Delete text encoder adapters
Expand Down Expand Up @@ -1284,8 +1292,8 @@ def get_active_adapters(self) -> List[str]:
from peft.tuners.tuners_utils import BaseTunerLayer

active_adapters = []

for module in self.unet.modules():
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
for module in unet.modules():
if isinstance(module, BaseTunerLayer):
active_adapters = module.active_adapters
break
Expand All @@ -1309,8 +1317,9 @@ def get_list_adapters(self) -> Dict[str, List[str]]:
if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"):
set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys())

if hasattr(self, "unet") and hasattr(self.unet, "peft_config"):
set_adapters["unet"] = list(self.unet.peft_config.keys())
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
if hasattr(self, self.unet_name) and hasattr(unet, "peft_config"):
set_adapters[self.unet_name] = list(self.unet.peft_config.keys())

return set_adapters

Expand All @@ -1331,7 +1340,8 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
from peft.tuners.tuners_utils import BaseTunerLayer

# Handle the UNET
for unet_module in self.unet.modules():
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
for unet_module in unet.modules():
if isinstance(unet_module, BaseTunerLayer):
for adapter_name in adapter_names:
unet_module.lora_A[adapter_name].to(device)
Expand Down

0 comments on commit 4c7e983

Please sign in to comment.