From 63e266c20b855f645ae62f23398c72119a20d156 Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Mon, 24 Oct 2022 14:43:18 +0800
Subject: [PATCH] fix bug for multiple prompts inputs

---
 .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index abcc7fba6e8a..2b885361e59b 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -340,8 +340,8 @@ def __call__(
             masked_image_latents = 0.18215 * masked_image_latents
 
         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
-        mask = mask.repeat(num_images_per_prompt, 1, 1, 1)
-        masked_image_latents = masked_image_latents.repeat(num_images_per_prompt, 1, 1, 1)
+        mask = mask.repeat(batch_size * num_images_per_prompt, 1, 1, 1)
+        masked_image_latents = masked_image_latents.repeat(batch_size * num_images_per_prompt, 1, 1, 1)
 
         mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
         masked_image_latents = (
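
Note (not part of the patch): the pipeline prepares its image latents with a leading dimension of batch_size * num_images_per_prompt, so repeating the mask only num_images_per_prompt times leaves a shape mismatch whenever more than one prompt is passed. A minimal sketch of the mismatch, using the pipeline's own variable names but hypothetical tensor sizes:

```python
import torch

batch_size = 2               # two prompts passed to the pipeline
num_images_per_prompt = 3
mask = torch.zeros(1, 1, 64, 64)                              # single mask from the input image
latents = torch.zeros(batch_size * num_images_per_prompt, 4, 64, 64)

# old behaviour: mask only covers num_images_per_prompt samples
old_mask = mask.repeat(num_images_per_prompt, 1, 1, 1)
print(old_mask.shape[0], latents.shape[0])                    # 3 vs. 6 -> mismatch downstream

# fixed behaviour: repeat once per generation across all prompts
new_mask = mask.repeat(batch_size * num_images_per_prompt, 1, 1, 1)
assert new_mask.shape[0] == latents.shape[0]                  # 6 == 6
```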