diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py
index 7dfded536..da4f6645f 100644
--- a/extensions-builtin/Lora/network_lora.py
+++ b/extensions-builtin/Lora/network_lora.py
@@ -27,7 +27,7 @@ def create_module(self, weights, key, none_ok=False):
         if weight is None and none_ok:
             return None
         linear_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear]
-        is_linear = type(self.sd_module) in linear_modules or self.sd_module.__class__.__name__ in {"NNCFLinear", "QLinear"}
+        is_linear = type(self.sd_module) in linear_modules or self.sd_module.__class__.__name__ in {"NNCFLinear", "QLinear", "Linear4bit"}
         is_conv = type(self.sd_module) in [torch.nn.Conv2d, diffusers_lora.LoRACompatibleConv] or self.sd_module.__class__.__name__ in {"NNCFConv2d", "QConv2d"}
         if is_linear:
             weight = weight.reshape(weight.shape[0], -1)
diff --git a/extensions-builtin/Lora/network_overrides.py b/extensions-builtin/Lora/network_overrides.py
index 9123d0039..24afb0c28 100644
--- a/extensions-builtin/Lora/network_overrides.py
+++ b/extensions-builtin/Lora/network_overrides.py
@@ -30,7 +30,6 @@
     'kandinsky',
     'hunyuandit',
     'auraflow',
-    'f1',
 ]
 
 force_classes = [ # forced always
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 6e20b5a59..fdbfd3401 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -50,23 +50,29 @@ def assign_network_names_to_compvis_modules(sd_model):
     network_layer_mapping = {}
     if shared.native:
-        if not hasattr(shared.sd_model, 'text_encoder') or not hasattr(shared.sd_model, 'unet'):
-            sd_model.network_layer_mapping = {}
-            return
-        for name, module in shared.sd_model.text_encoder.named_modules():
-            prefix = "lora_te1_" if shared.sd_model_type == "sdxl" else "lora_te_"
-            network_name = prefix + name.replace(".", "_")
-            network_layer_mapping[network_name] = module
-            module.network_layer_name = network_name
-        if shared.sd_model_type == "sdxl":
+        if hasattr(shared.sd_model, 'text_encoder'):
+            for name, module in shared.sd_model.text_encoder.named_modules():
+                prefix = "lora_te1_" if hasattr(shared.sd_model, 'text_encoder_2') else "lora_te_"
+                network_name = prefix + name.replace(".", "_")
+                network_layer_mapping[network_name] = module
+                module.network_layer_name = network_name
+        if hasattr(shared.sd_model, 'text_encoder_2'):
             for name, module in shared.sd_model.text_encoder_2.named_modules():
                 network_name = "lora_te2_" + name.replace(".", "_")
                 network_layer_mapping[network_name] = module
                 module.network_layer_name = network_name
-        for name, module in shared.sd_model.unet.named_modules():
-            network_name = "lora_unet_" + name.replace(".", "_")
-            network_layer_mapping[network_name] = module
-            module.network_layer_name = network_name
+        if hasattr(shared.sd_model, 'unet'):
+            for name, module in shared.sd_model.unet.named_modules():
+                network_name = "lora_unet_" + name.replace(".", "_")
+                network_layer_mapping[network_name] = module
+                module.network_layer_name = network_name
+        if hasattr(shared.sd_model, 'transformer'):
+            for name, module in shared.sd_model.transformer.named_modules():
+                network_name = "lora_transformer_" + name.replace(".", "_")
+                network_layer_mapping[network_name] = module
+                if "norm" in network_name and "linear" not in network_name:
+                    continue
+                module.network_layer_name = network_name
     else:
         if not hasattr(shared.sd_model, 'cond_stage_model'):
             sd_model.network_layer_mapping = {}