From c4d282360184686f7d4a66787a4be9898cf1b8b0 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Sat, 26 Aug 2023 23:31:56 +0200
Subject: [PATCH] [SDXL Lora] Fix last ben sdxl lora (#4797)

* Fix last ben sdxl lora

* Correct typo

* make style
---
 src/diffusers/loaders.py         | 41 +++++++++++++++++++++-----------
 tests/models/test_lora_layers.py | 16 +++++++++++++
 2 files changed, 43 insertions(+), 14 deletions(-)

diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index b5fb29758b9b..2acfdc594ff8 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -1084,7 +1084,7 @@ def lora_state_dict(
         # Map SDXL blocks correctly.
         if unet_config is not None:
             # use unet config to remap block numbers
-            state_dict = cls._map_sgm_blocks_to_diffusers(state_dict, unet_config)
+            state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
         state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)
 
         return state_dict, network_alphas
@@ -1121,24 +1121,41 @@ def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_ext
         return weight_name
 
     @classmethod
-    def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
-        is_all_unet = all(k.startswith("lora_unet") for k in state_dict)
+    def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
+        # 1. get all state_dict_keys
+        all_keys = state_dict.keys()
+        sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
+
+        # 2. check if needs remapping, if not return original dict
+        is_in_sgm_format = False
+        for key in all_keys:
+            if any(p in key for p in sgm_patterns):
+                is_in_sgm_format = True
+                break
+
+        if not is_in_sgm_format:
+            return state_dict
+
+        # 3. Else remap from SGM patterns
         new_state_dict = {}
         inner_block_map = ["resnets", "attentions", "upsamplers"]
 
         # Retrieves # of down, mid and up blocks
         input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
-        for layer in state_dict:
-            if "text" not in layer:
+
+        for layer in all_keys:
+            if "text" in layer:
+                new_state_dict[layer] = state_dict.pop(layer)
+            else:
                 layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
-                if "input_blocks" in layer:
+                if sgm_patterns[0] in layer:
                     input_block_ids.add(layer_id)
-                elif "middle_block" in layer:
+                elif sgm_patterns[1] in layer:
                     middle_block_ids.add(layer_id)
-                elif "output_blocks" in layer:
+                elif sgm_patterns[2] in layer:
                     output_block_ids.add(layer_id)
                 else:
-                    raise ValueError("Checkpoint not supported")
+                    raise ValueError(f"Checkpoint not supported because layer {layer} not supported.")
 
         input_blocks = {
             layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
@@ -1201,12 +1218,8 @@ def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", bl
                 )
                 new_state_dict[new_key] = state_dict.pop(key)
 
-        if is_all_unet and len(state_dict) > 0:
+        if len(state_dict) > 0:
             raise ValueError("At this point all state dict entries have to be converted.")
-        else:
-            # Remaining is the text encoder state dict.
-            for k, v in state_dict.items():
-                new_state_dict.update({k: v})
 
         return new_state_dict
 
diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py
index c2fe98993d00..cbede0d124ee 100644
--- a/tests/models/test_lora_layers.py
+++ b/tests/models/test_lora_layers.py
@@ -942,3 +942,19 @@ def test_sdxl_1_0_lora(self):
         expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])
 
         self.assertTrue(np.allclose(images, expected, atol=1e-4))
+
+    def test_sdxl_1_0_last_ben(self):
+        generator = torch.Generator().manual_seed(0)
+
+        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+        pipe.enable_model_cpu_offload()
+        lora_model_id = "TheLastBen/Papercut_SDXL"
+        lora_filename = "papercut.safetensors"
+        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+
+        images = pipe("papercut.safetensors", output_type="np", generator=generator, num_inference_steps=2).images
+
+        images = images[0, -3:, -3:, -1].flatten()
+        expected = np.array([0.5244, 0.4347, 0.4312, 0.4246, 0.4398, 0.4409, 0.4884, 0.4938, 0.4094])
+
+        self.assertTrue(np.allclose(images, expected, atol=1e-3))
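
For context, a minimal standalone sketch of the user-facing flow the new test covers: loading TheLastBen/Papercut_SDXL into the SDXL base pipeline with load_lora_weights, whose lora_state_dict path now applies the SGM block remapping only when the checkpoint keys actually contain input_blocks/middle_block/output_blocks. The prompt, step count, and output filename below are illustrative placeholders, not taken from the patch.

# Sketch (not part of the patch): exercise the fixed LoRA loading path.
import torch
from diffusers import DiffusionPipeline

# Load the SDXL base pipeline and offload submodules to CPU when idle.
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.enable_model_cpu_offload()

# After this patch, the SGM-to-diffusers block remapping is skipped for
# checkpoints that are not in SGM key format, so this LoRA loads cleanly.
pipe.load_lora_weights("TheLastBen/Papercut_SDXL", weight_name="papercut.safetensors")

generator = torch.Generator().manual_seed(0)
image = pipe(
    "papercut, a fox in a snowy forest",  # illustrative prompt, not from the patch
    generator=generator,
    num_inference_steps=30,  # illustrative; the test uses 2 steps for speed
).images[0]
image.save("papercut_fox.png")  # hypothetical output path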