diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 85effa19b59d..849becae4a61 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1178,7 +1178,7 @@ def load_model_hook(models, input_dir): _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) _set_state_dict_into_text_encoder( - lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_one_ + lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_ ) # Make sure the trainable params are in float32. This is again needed since the base models