Add option, fix diffusers keys
AI-Casanova committed Sep 29, 2024
1 parent be7f86f commit bd5ac8e
Showing 3 changed files with 14 additions and 10 deletions.
2 changes: 2 additions & 0 deletions extensions-builtin/Lora/lora_convert.py
@@ -175,6 +175,8 @@ def diffusers(self, key):
             if search_key.startswith(map_key):
                 key = key.replace(map_key, self.UNET_CONVERSION_MAP[map_key]).replace("oft", "lora") # pylint: disable=unsubscriptable-object
         sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+        if sd_module is None:
+            sd_module = shared.sd_model.network_layer_mapping.get(key.replace("guidance", "timestep"), None) # FLUX1 fix
         # SegMoE begin
         expert_key = key + "_experts_0"
         expert_module = shared.sd_model.network_layer_mapping.get(expert_key, None)
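The two added lines give the key lookup a fallback for FLUX.1 LoRAs: if a converted key containing "guidance" has no entry in the model's layer mapping, the same key with "guidance" replaced by "timestep" is tried before giving up. A minimal sketch of that fallback, assuming a plain dict in place of shared.sd_model.network_layer_mapping; resolve_module and the example keys are illustrative, not repo code:

# Sketch only: layer_mapping stands in for shared.sd_model.network_layer_mapping
def resolve_module(layer_mapping: dict, key: str):
    module = layer_mapping.get(key, None)
    if module is None:
        # FLUX1 fix: some models register the embedder under "timestep" rather than "guidance"
        module = layer_mapping.get(key.replace("guidance", "timestep"), None)
    return module

# hypothetical mapping where only the "timestep" variant of the key exists
mapping = {"transformer_time_text_embed_timestep_embedder_linear_1": object()}
print(resolve_module(mapping, "transformer_time_text_embed_guidance_embedder_linear_1") is not None)  # True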
21 changes: 11 additions & 10 deletions extensions-builtin/Lora/networks.py
@@ -134,8 +134,8 @@ def load_network(name, network_on_disk) -> network.Network:
     net = network.Network(name, network_on_disk)
     net.mtime = os.path.getmtime(network_on_disk.filename)
     sd = sd_models.read_state_dict(network_on_disk.filename, what='network')
-    if shared.sd_model_type == 'f1':
-        sd = lora_convert._convert_kohya_flux_lora_to_diffusers(sd) or sd # if kohya flux lora, convert state_dict
+    if shared.sd_model_type == 'f1': # if kohya flux lora, convert state_dict
+        sd = lora_convert._convert_kohya_flux_lora_to_diffusers(sd) or sd # pylint: disable=protected-access
     assign_network_names_to_compvis_modules(shared.sd_model) # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
     keys_failed_to_match = {}
     matched_networks = {}
@@ -296,11 +296,11 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
         elif hasattr(self, "qweight") and hasattr(self, "freeze"):
             self.weight = torch.nn.Parameter(weights_backup.to(self.weight.device, copy=True))
             self.freeze()
-        elif getattr(self.weight, "quant_type", None) is not None:
+        elif getattr(self, "quant_type", None) is not None:
             import bitsandbytes
             device = self.weight.device
-            self.weight = bitsandbytes.nn.Params4bit(weights_backup, quant_state=self.weight.quant_state,
-                                                     quant_type=self.weight.quant_type, blocksize=self.weight.blocksize)
+            self.weight = bitsandbytes.nn.Params4bit(weights_backup, quant_state=self.quant_state,
+                                                     quant_type=self.quant_type, blocksize=self.blocksize)
             self.weight.to(device)
         else:
             self.weight.copy_(weights_backup)
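Reading quant_type, quant_state and blocksize from the module rather than from self.weight matters because the weight may have been rebuilt during a merge and no longer carry the original quantization metadata; the values are cached on the module when the backup is taken (see the next hunk) and consumed here. A tiny runnable illustration of why the cache has to live on the module and not on the tensor, using a plain torch.nn.Linear as a stand-in for a quantized layer:

import torch

layer = torch.nn.Linear(4, 4)
layer.weight.quant_type = "nf4"             # pretend quantization metadata attached to the tensor
layer.quant_type = layer.weight.quant_type  # cached on the module, as the backup hunk below does

layer.weight = torch.nn.Parameter(torch.zeros(4, 4))  # the weight gets replaced by a merge
print(getattr(layer.weight, "quant_type", None))      # None - lost together with the old tensor
print(getattr(layer, "quant_type", None))             # "nf4" - still available for the restore path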
@@ -337,14 +337,16 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
         if isinstance(self, torch.nn.MultiheadAttention):
             weights_backup = (self.in_proj_weight.clone().to(devices.cpu), self.out_proj.weight.clone().to(devices.cpu))
         elif getattr(self.weight, "quant_type", None) == "nf4" or getattr(self.weight, "quant_type", None) == "nf4":
-            # weights_backup = self.weight.__deepcopy__("")
             import bitsandbytes
             with devices.inference_context():
                 weights_backup = bitsandbytes.functional.dequantize_4bit(self.weight,
                                                                          quant_state=self.weight.quant_state,
                                                                          quant_type=self.weight.quant_type,
                                                                          blocksize=self.weight.blocksize,
                                                                          ).to(devices.cpu)
+            self.quant_state = self.weight.quant_state
+            self.quant_type = self.weight.quant_type
+            self.blocksize = self.weight.blocksize
         else:
             weights_backup = self.weight.clone().to(devices.cpu)
         self.network_weights_backup = weights_backup
@@ -378,10 +380,9 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
                                                                  quant_state=self.weight.quant_state,
                                                                  quant_type=self.weight.quant_type,
                                                                  blocksize=self.weight.blocksize)
-                self.weight = bitsandbytes.nn.Params4bit(weight + updown,
-                                                         quant_state=self.weight.quant_state,
-                                                         quant_type=self.weight.quant_type,
-                                                         blocksize=self.weight.blocksize)
+                self.weight = bitsandbytes.nn.Params4bit(weight + updown, quant_state=self.quant_state,
+                                                         quant_type=shared.opts.lora_quant.lower(),
+                                                         blocksize=self.blocksize)
                 self.weight.to(device)
             else:
                 self.weight = torch.nn.Parameter(weight + updown)
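Taken together, the networks.py hunks turn the quantized path into a dequantize -> merge -> re-quantize roundtrip: the 4-bit weight is dequantized, the LoRA delta is added, and a new Params4bit is built with the cached quant state and blocksize while quant_type now follows the new lora_quant setting; moving the parameter back onto the device is what lets bitsandbytes quantize the merged tensor again. A hedged sketch of that flow, assuming a CUDA device and a bitsandbytes 4-bit layer; merge_lora_into_4bit, module, updown and chosen_quant_type are illustrative names, not the repo's functions:

import torch
import bitsandbytes

def merge_lora_into_4bit(module: torch.nn.Linear, updown: torch.Tensor, chosen_quant_type: str = "fp4"):
    # sketch of the roundtrip implemented above; chosen_quant_type plays the role of shared.opts.lora_quant.lower()
    device = module.weight.device
    with torch.no_grad():
        weight = bitsandbytes.functional.dequantize_4bit(module.weight,
                                                         quant_state=module.weight.quant_state,
                                                         quant_type=module.weight.quant_type,
                                                         blocksize=module.weight.blocksize)
        merged = weight + updown.to(weight)  # match device/dtype of the dequantized weight
        module.weight = bitsandbytes.nn.Params4bit(merged,
                                                   quant_state=module.weight.quant_state,
                                                   quant_type=chosen_quant_type,
                                                   blocksize=module.weight.blocksize)
        module.weight.to(device)  # moving to the device re-quantizes the merged tensor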
1 change: 1 addition & 0 deletions modules/shared.py
@@ -862,6 +862,7 @@ def get_default_modes():
     "extra_networks_styles": OptionInfo(True, "Show built-in styles"),
     "lora_preferred_name": OptionInfo("filename", "LoRA preferred name", gr.Radio, {"choices": ["filename", "alias"]}),
     "lora_add_hashes_to_infotext": OptionInfo(False, "LoRA add hash info"),
+    "lora_quant": OptionInfo("FP4", "LoRA precision for merged layers in quantized models", gr.Radio, {"choices": ["FP4", "NF4"]}),
     "lora_force_diffusers": OptionInfo(False if not cmd_opts.use_openvino else True, "LoRA force loading of all models using Diffusers"),
     "lora_maybe_diffusers": OptionInfo(False, "LoRA force loading of specific models using Diffusers"),
     "lora_fuse_diffusers": OptionInfo(False if not cmd_opts.use_openvino else True, "LoRA use merge when using alternative method"),
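The new lora_quant option is consumed on the merge path above as shared.opts.lora_quant.lower(), so the UI choices "FP4" and "NF4" are simply lower-cased into the quant_type strings that bitsandbytes 4-bit quantization accepts. A trivial illustration of that mapping; the Opts class is a stand-in for shared.opts:

class Opts:
    lora_quant = "NF4"  # the option added above defaults to "FP4"

quant_type = Opts.lora_quant.lower()
assert quant_type in ("fp4", "nf4")  # values understood by bitsandbytes
print(quant_type)  # nf4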
