
Commit

Flux-LoRa first commit
AI-Casanova committed Sep 27, 2024
1 parent: bcb704a · commit: 314e333
Showing 3 changed files with 20 additions and 15 deletions.
2 changes: 1 addition & 1 deletion extensions-builtin/Lora/network_lora.py
@@ -27,7 +27,7 @@ def create_module(self, weights, key, none_ok=False):
         if weight is None and none_ok:
             return None
         linear_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention, diffusers_lora.LoRACompatibleLinear]
-        is_linear = type(self.sd_module) in linear_modules or self.sd_module.__class__.__name__ in {"NNCFLinear", "QLinear"}
+        is_linear = type(self.sd_module) in linear_modules or self.sd_module.__class__.__name__ in {"NNCFLinear", "QLinear", "Linear4bit"}
         is_conv = type(self.sd_module) in [torch.nn.Conv2d, diffusers_lora.LoRACompatibleConv] or self.sd_module.__class__.__name__ in {"NNCFConv2d", "QConv2d"}
         if is_linear:
             weight = weight.reshape(weight.shape[0], -1)
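
A note on this hunk: the linear check matches on self.sd_module.__class__.__name__ rather than on the type itself, so optional quantization backends never have to be importable here. Linear4bit is presumably the bitsandbytes 4-bit linear layer used by NF4-quantized Flux checkpoints (an assumption; the diff itself does not say). A minimal standalone sketch of the same pattern, with module standing in for self.sd_module:

# Hypothetical sketch, not part of the commit: treat a module as "linear-like"
# if it is a plain torch Linear, or if its class name matches a known
# quantized/wrapped linear class (NNCF, quanto, bitsandbytes) without importing those backends.
import torch

QUANTIZED_LINEAR_NAMES = {"NNCFLinear", "QLinear", "Linear4bit"}

def is_linear_like(module: torch.nn.Module) -> bool:
    return isinstance(module, torch.nn.Linear) or module.__class__.__name__ in QUANTIZED_LINEAR_NAMES
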
1 change: 0 additions & 1 deletion extensions-builtin/Lora/network_overrides.py
@@ -30,7 +30,6 @@
     'kandinsky',
     'hunyuandit',
     'auraflow',
-    'f1',
 ]
 
 force_classes = [ # forced always
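
For context (an assumption based on the file and list contents, not stated in the hunk): this list enumerates model types whose LoRAs are forced onto an alternative loading path, so dropping 'f1' (the Flux model type) lets Flux LoRAs go through the standard handling updated in networks.py below. A hypothetical illustration of how such a list is typically consumed; the real list name and call site are outside this hunk:

# Hypothetical sketch only; names are illustrative, not the actual identifiers in network_overrides.py.
forced_model_types = ['kandinsky', 'hunyuandit', 'auraflow']  # 'f1' (Flux) no longer listed

def is_forced(model_type: str) -> bool:
    return model_type in forced_model_types
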
32 changes: 19 additions & 13 deletions extensions-builtin/Lora/networks.py
@@ -50,23 +50,29 @@
 def assign_network_names_to_compvis_modules(sd_model):
     network_layer_mapping = {}
     if shared.native:
-        if not hasattr(shared.sd_model, 'text_encoder') or not hasattr(shared.sd_model, 'unet'):
-            sd_model.network_layer_mapping = {}
-            return
-        for name, module in shared.sd_model.text_encoder.named_modules():
-            prefix = "lora_te1_" if shared.sd_model_type == "sdxl" else "lora_te_"
-            network_name = prefix + name.replace(".", "_")
-            network_layer_mapping[network_name] = module
-            module.network_layer_name = network_name
-        if shared.sd_model_type == "sdxl":
+        if hasattr(shared.sd_model, 'text_encoder'):
+            for name, module in shared.sd_model.text_encoder.named_modules():
+                prefix = "lora_te1_" if hasattr(shared.sd_model, 'text_encoder_2') else "lora_te_"
+                network_name = prefix + name.replace(".", "_")
+                network_layer_mapping[network_name] = module
+                module.network_layer_name = network_name
+        if hasattr(shared.sd_model, 'text_encoder_2'):
             for name, module in shared.sd_model.text_encoder_2.named_modules():
                 network_name = "lora_te2_" + name.replace(".", "_")
                 network_layer_mapping[network_name] = module
                 module.network_layer_name = network_name
-        for name, module in shared.sd_model.unet.named_modules():
-            network_name = "lora_unet_" + name.replace(".", "_")
-            network_layer_mapping[network_name] = module
-            module.network_layer_name = network_name
+        if hasattr(shared.sd_model, 'unet'):
+            for name, module in shared.sd_model.unet.named_modules():
+                network_name = "lora_unet_" + name.replace(".", "_")
+                network_layer_mapping[network_name] = module
+                module.network_layer_name = network_name
+        if hasattr(shared.sd_model, 'transformer'):
+            for name, module in shared.sd_model.transformer.named_modules():
+                network_name = "lora_transformer_" + name.replace(".", "_")
+                network_layer_mapping[network_name] = module
+                if "norm" in network_name and "linear" not in network_name:
+                    continue
+                module.network_layer_name = network_name
     else:
         if not hasattr(shared.sd_model, 'cond_stage_model'):
             sd_model.network_layer_mapping = {}
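
This hunk is the core of the Flux change: rather than bailing out unless both text_encoder and unet exist, each pipeline component is now mapped only if present, and DiT-style pipelines that expose a transformer get lora_transformer_-prefixed layer names, with pure norm layers (no "linear" in the name) skipped when assigning network_layer_name. A rough sketch of the resulting naming scheme, assuming a diffusers-style pipeline object (pipe stands in for shared.sd_model; the norm-skip detail is omitted):

# Hypothetical sketch of the prefix scheme used above, not the commit's code.
def build_layer_mapping(pipe):
    mapping = {}
    prefixes = {
        'text_encoder': 'lora_te1_' if hasattr(pipe, 'text_encoder_2') else 'lora_te_',
        'text_encoder_2': 'lora_te2_',
        'unet': 'lora_unet_',
        'transformer': 'lora_transformer_',
    }
    for attr, prefix in prefixes.items():
        component = getattr(pipe, attr, None)  # map a component only if the pipeline has it
        if component is None:
            continue
        for name, module in component.named_modules():
            mapping[prefix + name.replace('.', '_')] = module
    return mapping
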
