
Dreambooth SDXL LoRA - mat1 and mat2 shapes cannot be multiplied (2x2048 and 2816x1280) #7239

Closed
shawnrushefsky opened this issue Mar 6, 2024 · 5 comments · Fixed by #7242
Labels: bug (Something isn't working), training

Comments

@shawnrushefsky

shawnrushefsky commented Mar 6, 2024

Describe the bug

Trying to run an SDXL LoRA DreamBooth training job with prior preservation. After the class images are generated, it fails with the logged error mat1 and mat2 shapes cannot be multiplied (2x2048 and 2816x1280). Instance images are 37 photos of my dog in a variety of sizes.

Reproduction

accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
  --instance_data_dir=/instance_images \
  --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
  --output_dir=/output \
  --instance_prompt="timber" \
  --mixed_precision=fp16 \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-05 \
  --lr_scheduler=constant \
  --lr_warmup_steps=0 \
  --checkpointing_steps=100 \
  --seed=0 \
  --resume_from_checkpoint=latest \
  --checkpoints_total_limit=1 \
  --max_train_steps=1800 \
  --train_text_encoder \
  --text_encoder_lr=5e-06 \
  --with_prior_preservation \
  --class_data_dir=/class_images \
  --class_prompt="photo of a dog" \
  --num_class_images=25 \
  --validation_prompt="timber as an ace space pilot, detailed illustration" \
  --validation_epochs=10 \
  --report_to=wandb \
  --sample_batch_size=4

Logs

Steps:   0%|          | 0/1800 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/app/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py", line 1793, in <module>
    main(args)
  File "/app/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py", line 1572, in main
    model_pred = unet(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 817, in forward
    return model_forward(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py", line 805, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/opt/conda/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/app/diffusers/src/diffusers/models/unets/unet_2d_condition.py", line 1162, in forward
    aug_emb = self.get_aug_embed(
  File "/app/diffusers/src/diffusers/models/unets/unet_2d_condition.py", line 987, in get_aug_embed
    aug_emb = self.add_embedding(add_embeds)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/app/diffusers/src/diffusers/models/embeddings.py", line 228, in forward
    sample = self.linear_1(sample)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x2048 and 2816x1280)
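One way to read the numbers in that error: F.linear(input, weight) multiplies the input by the transposed weight, and a linear layer stores its weight as (out_features, in_features), so mat2 is printed as in_features x out_features. The helper below is a hypothetical, stdlib-only sketch (decode_linear_mismatch is not part of PyTorch or diffusers) that pulls those numbers out of the message.

```python
import re


def decode_linear_mismatch(message: str) -> dict:
    """Pull the shapes out of a PyTorch 'mat1 and mat2' RuntimeError raised by F.linear.

    Hypothetical helper. F.linear(input, weight) multiplies input by weight
    transposed, and a Linear weight is stored as (out_features, in_features),
    so mat2 is printed as in_features x out_features.
    """
    match = re.search(r"\((\d+)x(\d+) and (\d+)x(\d+)\)", message)
    if match is None:
        raise ValueError("no '(AxB and CxD)' shape pair found in message")
    rows, input_width, in_features, out_features = map(int, match.groups())
    return {
        "rows": rows,                  # rows (samples) fed to the layer
        "input_width": input_width,    # features actually supplied per row
        "in_features": in_features,    # features the layer expects per row
        "out_features": out_features,  # width of the layer's output
        "shortfall": in_features - input_width,
    }


err = "RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x2048 and 2816x1280)"
print(decode_linear_mismatch(err))
# {'rows': 2, 'input_width': 2048, 'in_features': 2816, 'out_features': 1280, 'shortfall': 768}
```

Read this way, add_embedding.linear_1 expects 2816 features per row but receives 2048, a shortfall of 768. The 2 rows are presumably one instance and one class example, since prior preservation puts both into each batch when --train_batch_size=1.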

System Info

  • diffusers version: 0.27.0.dev0
  • Platform: Linux-5.15.146.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
  • Python version: 3.10.13
  • PyTorch version (GPU?): 2.2.0 (True)
  • Huggingface_hub version: 0.21.3
  • Transformers version: 4.38.2
  • Accelerate version: 0.27.0
  • xFormers version: 0.0.24
  • Using GPU in script?: yes, RTX 4090
  • Using distributed or parallel set-up in script?: no

Who can help?

@sayakpaul

@shawnrushefsky shawnrushefsky added the bug (Something isn't working) label Mar 6, 2024
@shawnrushefsky shawnrushefsky changed the title from "Dreambooth SDXL LoRA" to "Dreambooth SDXL LoRA - mat1 and mat2 shapes cannot be multiplied (2x2048 and 2816x1280)" Mar 6, 2024
@sayakpaul
Member

Cc: @linoytsaban, have you faced this?

@linoytsaban
Collaborator

@sayakpaul Oh, I think it's related to the micro-conditioning!
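The thread points at micro-conditioning without spelling out the root cause. The plain-Python arithmetic below is a sketch under stated assumptions: SDXL base dimensions (a 1280-wide pooled text embedding, a 256-wide sinusoidal embedding per time id, 2816 in_features for add_embedding.linear_1), and a UNet that flattens the time ids, embeds each one, and reshapes the result to (batch_size, -1) before concatenating it with the pooled embedding. The helper names sdxl_time_ids and add_embeds_width are hypothetical, not the training script's API. Under these assumptions, six time ids per sample (original size, crop top-left, target size) give 2816, while a batch of two that receives only one sample's six time ids gives exactly the 2048 seen in the error.

```python
# Assumed SDXL base-1.0 dimensions (taken as given for this sketch):
POOLED_DIM = 1280      # pooled output width of the second text encoder
TIME_EMBED_DIM = 256   # sinusoidal embedding width per micro-conditioning value
EXPECTED_IN = 2816     # in_features of add_embedding.linear_1


def sdxl_time_ids(original_size, crops_coords_top_left, target_size):
    """Hypothetical helper: six micro-conditioning values for one sample.

    Order: original (h, w), crop (top, left), target (h, w).
    """
    return list(original_size + crops_coords_top_left + target_size)


def add_embeds_width(batch_size: int, total_time_ids: int) -> int:
    """Width of one add_embeds row, mimicking flatten + reshape((batch_size, -1)).

    Each time id becomes TIME_EMBED_DIM values; the flattened result is split
    evenly across the batch and appended to the pooled text embedding.
    """
    total_values = total_time_ids * TIME_EMBED_DIM
    if total_values % batch_size:
        raise ValueError("time embeddings cannot be split evenly across the batch")
    return POOLED_DIM + total_values // batch_size


# One instance + one class sample (prior preservation, train_batch_size=1).
ids = [sdxl_time_ids((1024, 1024), (0, 0), (1024, 1024)) for _ in range(2)]
print(add_embeds_width(2, sum(len(i) for i in ids)))  # 2816: matches linear_1
print(add_embeds_width(2, len(ids[0])))               # 2048: the width in the error
```

The arithmetic only shows that a 768-wide shortfall is what you would get if time ids for one sample were spread across a two-sample batch; it does not confirm that this is what #7242 changed.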

@sayakpaul
Member

Have fixed it with #7242 :)

@yjdqk

yjdqk commented Jan 14, 2025

@shawnrushefsky, hello, I ran into the same problem and still cannot solve it. What does "--num_class_images" mean? How can I solve the problem?

@yjdqk

yjdqk commented Jan 14, 2025

Error message: RuntimeError: mat1 and mat2 shapes cannot be multiplied (771536 and 768320)
