[SDXL Lora] Fix last ben sdxl lora (#4797)
* Fix last ben sdxl lora

* Correct typo

* make style
patrickvonplaten authored Aug 26, 2023
1 parent 4f8853e commit c4d2823
Showing 2 changed files with 43 additions and 14 deletions.
41 changes: 27 additions & 14 deletions src/diffusers/loaders.py
@@ -1084,7 +1084,7 @@ def lora_state_dict(
             # Map SDXL blocks correctly.
             if unet_config is not None:
                 # use unet config to remap block numbers
-                state_dict = cls._map_sgm_blocks_to_diffusers(state_dict, unet_config)
+                state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
             state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)

         return state_dict, network_alphas
@@ -1121,24 +1121,41 @@ def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_ext
         return weight_name

     @classmethod
-    def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
-        is_all_unet = all(k.startswith("lora_unet") for k in state_dict)
+    def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
+        # 1. get all state_dict_keys
+        all_keys = state_dict.keys()
+        sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
+
+        # 2. check if needs remapping, if not return original dict
+        is_in_sgm_format = False
+        for key in all_keys:
+            if any(p in key for p in sgm_patterns):
+                is_in_sgm_format = True
+                break
+
+        if not is_in_sgm_format:
+            return state_dict
+
+        # 3. Else remap from SGM patterns
         new_state_dict = {}
         inner_block_map = ["resnets", "attentions", "upsamplers"]

         # Retrieves # of down, mid and up blocks
         input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
-        for layer in state_dict:
-            if "text" not in layer:
+
+        for layer in all_keys:
+            if "text" in layer:
+                new_state_dict[layer] = state_dict.pop(layer)
+            else:
                 layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
-                if "input_blocks" in layer:
+                if sgm_patterns[0] in layer:
                     input_block_ids.add(layer_id)
-                elif "middle_block" in layer:
+                elif sgm_patterns[1] in layer:
                     middle_block_ids.add(layer_id)
-                elif "output_blocks" in layer:
+                elif sgm_patterns[2] in layer:
                     output_block_ids.add(layer_id)
                 else:
-                    raise ValueError("Checkpoint not supported")
+                    raise ValueError(f"Checkpoint not supported because layer {layer} not supported.")

         input_blocks = {
             layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
@@ -1201,12 +1218,8 @@ def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", bl
                 )
                 new_state_dict[new_key] = state_dict.pop(key)

-        if is_all_unet and len(state_dict) > 0:
+        if len(state_dict) > 0:
             raise ValueError("At this point all state dict entries have to be converted.")
-        else:
-            # Remaining is the text encoder state dict.
-            for k, v in state_dict.items():
-                new_state_dict.update({k: v})

         return new_state_dict

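The renamed _maybe_map_sgm_blocks_to_diffusers first checks whether any key uses SGM-style block names (input_blocks, middle_block, output_blocks) and returns the state dict unchanged if none do, rather than assuming every non-text key needs remapping. Below is a minimal standalone sketch of just that detection step; the key names are illustrative assumptions and this is not the diffusers implementation itself.

sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]

def needs_sgm_remapping(state_dict):
    # True only if at least one key uses the SGM block naming scheme.
    return any(p in key for key in state_dict for p in sgm_patterns)

# Hypothetical keys illustrating the two naming styles:
sgm_style_key = "lora_unet_input_blocks_1_0_emb_layers_1.lora_down.weight"
diffusers_style_key = "lora_unet_down_blocks_0_attentions_0_proj_in.lora_down.weight"

assert needs_sgm_remapping({sgm_style_key: None})
assert not needs_sgm_remapping({diffusers_style_key: None})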
16 changes: 16 additions & 0 deletions tests/models/test_lora_layers.py
@@ -942,3 +942,19 @@ def test_sdxl_1_0_lora(self):
         expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])

         self.assertTrue(np.allclose(images, expected, atol=1e-4))
+
+    def test_sdxl_1_0_last_ben(self):
+        generator = torch.Generator().manual_seed(0)
+
+        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+        pipe.enable_model_cpu_offload()
+        lora_model_id = "TheLastBen/Papercut_SDXL"
+        lora_filename = "papercut.safetensors"
+        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+
+        images = pipe("papercut.safetensors", output_type="np", generator=generator, num_inference_steps=2).images
+
+        images = images[0, -3:, -3:, -1].flatten()
+        expected = np.array([0.5244, 0.4347, 0.4312, 0.4246, 0.4398, 0.4409, 0.4884, 0.4938, 0.4094])
+
+        self.assertTrue(np.allclose(images, expected, atol=1e-3))
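The new slow test mirrors the end-user flow for loading this LoRA. A minimal usage sketch along the same lines follows; the prompt, step count, and output filename are illustrative assumptions, not values taken from the test.

from diffusers import DiffusionPipeline

# Same base model and LoRA as in test_sdxl_1_0_last_ben.
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.enable_model_cpu_offload()
pipe.load_lora_weights("TheLastBen/Papercut_SDXL", weight_name="papercut.safetensors")

# Prompt and step count are assumptions for illustration only.
image = pipe("papercut style fox", num_inference_steps=30).images[0]
image.save("papercut_fox.png")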
