diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index e98d4ad4e37b..ce2f6585c601 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -400,15 +400,22 @@ def prepare_ip_adapter_image_embeds( image_embeds.append(single_image_embeds) else: + repeat_dims = [1] image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) image_embeds.append(single_image_embeds) return image_embeds @@ -509,9 +516,9 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) - elif ip_adapter_image_embeds[0].ndim != 3: + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) # Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index ced64889044f..bcfcd3a24b5d 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -478,15 +478,22 @@ def prepare_ip_adapter_image_embeds( image_embeds.append(single_image_embeds) else: + repeat_dims = [1] image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) image_embeds.append(single_image_embeds) return image_embeds @@ -589,9 +596,9 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) - elif ip_adapter_image_embeds[0].ndim != 3: + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) def get_timesteps(self, num_inference_steps, timesteps, strength, device): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index eab2f7aa22d0..8f31dfc2678a 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -510,15 +510,22 @@ def prepare_ip_adapter_image_embeds( image_embeds.append(single_image_embeds) else: + repeat_dims = [1] image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) image_embeds.append(single_image_embeds) return image_embeds @@ -726,9 +733,9 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) - elif ip_adapter_image_embeds[0].ndim != 3: + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) def check_image(self, image, prompt, prompt_embeds): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 4fc9791d3d8e..9d2c76fd7483 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -503,15 +503,22 @@ def prepare_ip_adapter_image_embeds( image_embeds.append(single_image_embeds) else: + repeat_dims = [1] image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) image_embeds.append(single_image_embeds) return image_embeds @@ -713,9 +720,9 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) - elif ip_adapter_image_embeds[0].ndim != 3: + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index ce7537d84215..c4f1bff5efcd 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -628,15 +628,22 @@ def prepare_ip_adapter_image_embeds( image_embeds.append(single_image_embeds) else: + repeat_dims = [1] image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) image_embeds.append(single_image_embeds) return image_embeds @@ -871,9 +878,9 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) - elif ip_adapter_image_embeds[0].ndim != 3: + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 0f1d5ea48e71..52ffe5a3f356 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -537,15 +537,22 @@ def prepare_ip_adapter_image_embeds( image_embeds.append(single_image_embeds) else: + repeat_dims = [1] image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) image_embeds.append(single_image_embeds) return image_embeds @@ -817,9 +824,9 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) - elif ip_adapter_image_embeds[0].ndim != 3: + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) def prepare_control_image( diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 9883b4f64790..0b611350a6f1 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -515,15 +515,22 @@ def prepare_ip_adapter_image_embeds( image_embeds.append(single_image_embeds) else: + repeat_dims = [1] image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) image_embeds.append(single_image_embeds) return image_embeds @@ -730,9 +737,9 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) - elif ip_adapter_image_embeds[0].ndim != 3: + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index cf32ae81c562..4deee37f7df1 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -567,15 +567,22 @@ def prepare_ip_adapter_image_embeds( image_embeds.append(single_image_embeds) else: + repeat_dims = [1] image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) image_embeds.append(single_image_embeds) return image_embeds @@ -794,9 +801,9 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) - elif ip_adapter_image_embeds[0].ndim != 3: + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) # Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index 6d1b1a0db444..f64854ea982b 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -453,15 +453,22 @@ def prepare_ip_adapter_image_embeds( image_embeds.append(single_image_embeds) else: + repeat_dims = [1] image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) image_embeds.append(single_image_embeds) return image_embeds @@ -647,9 +654,9 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) - elif ip_adapter_image_embeds[0].ndim != 3: + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) @property diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py index fa27c0fbd5bc..e9bacaa89ba5 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py @@ -437,15 +437,22 @@ def prepare_ip_adapter_image_embeds( image_embeds.append(single_image_embeds) else: + repeat_dims = [1] image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) image_embeds.append(single_image_embeds) return image_embeds @@ -579,9 +586,9 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) - elif ip_adapter_image_embeds[0].ndim != 3: + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) @property diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 485ccb22e5e9..bd3e2891f0d6 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -582,9 +582,9 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) - elif ip_adapter_image_embeds[0].ndim != 3: + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds @@ -619,15 +619,22 @@ def prepare_ip_adapter_image_embeds( image_embeds.append(single_image_embeds) else: + repeat_dims = [1] image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) image_embeds.append(single_image_embeds) return image_embeds diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 5126e6f4c378..9e4e6c186ffa 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -520,15 +520,22 @@ def prepare_ip_adapter_image_embeds( image_embeds.append(single_image_embeds) else: + repeat_dims = [1] image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) image_embeds.append(single_image_embeds) return image_embeds @@ -639,9 +646,9 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) - elif ip_adapter_image_embeds[0].ndim != 3: + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 9c6fbb2310ac..b43e0eb2abcd 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -564,15 +564,22 @@ def prepare_ip_adapter_image_embeds( image_embeds.append(single_image_embeds) else: + repeat_dims = [1] image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) image_embeds.append(single_image_embeds) return image_embeds @@ -685,9 +692,9 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) - elif ip_adapter_image_embeds[0].ndim != 3: + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) def get_timesteps(self, num_inference_steps, strength, device): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 3392fd6ddecc..221d5c2cfd3f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -636,15 +636,22 @@ def prepare_ip_adapter_image_embeds( image_embeds.append(single_image_embeds) else: + repeat_dims = [1] image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) image_embeds.append(single_image_embeds) return image_embeds @@ -767,9 +774,9 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) - elif ip_adapter_image_embeds[0].ndim != 3: + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) def prepare_latents( diff --git a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py index 502cd340bcd8..dbfb5e08ef23 100644 --- a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +++ b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py @@ -442,15 +442,22 @@ def prepare_ip_adapter_image_embeds( image_embeds.append(single_image_embeds) else: + repeat_dims = [1] image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) image_embeds.append(single_image_embeds) return image_embeds @@ -553,9 +560,9 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) - elif ip_adapter_image_embeds[0].ndim != 3: + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py index 49cc68926b7e..feda710e0049 100644 --- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py @@ -414,15 +414,22 @@ def prepare_ip_adapter_image_embeds( image_embeds.append(single_image_embeds) else: + repeat_dims = [1] image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) image_embeds.append(single_image_embeds) return image_embeds @@ -550,9 +557,9 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) - elif ip_adapter_image_embeds[0].ndim != 3: + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 4a34ae89d245..776696e9d486 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -549,15 +549,22 @@ def prepare_ip_adapter_image_embeds( image_embeds.append(single_image_embeds) else: + repeat_dims = [1] image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) image_embeds.append(single_image_embeds) return image_embeds @@ -671,9 +678,9 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) - elif ip_adapter_image_embeds[0].ndim != 3: + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index ef25ca94d16c..5ba12baad065 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -616,9 +616,9 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) - elif ip_adapter_image_embeds[0].ndim != 3: + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): @@ -782,15 +782,22 @@ def prepare_ip_adapter_image_embeds( image_embeds.append(single_image_embeds) else: + repeat_dims = [1] image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) image_embeds.append(single_image_embeds) return image_embeds diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 2cb946eb56ad..5b9628f51a41 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -486,15 +486,22 @@ def prepare_ip_adapter_image_embeds( image_embeds.append(single_image_embeds) else: + repeat_dims = [1] image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) image_embeds.append(single_image_embeds) return image_embeds @@ -851,9 +858,9 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) - elif ip_adapter_image_embeds[0].ndim != 3: + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) def prepare_latents( diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index abf743f4f305..4e0cc61f5c1d 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -563,15 +563,22 @@ def prepare_ip_adapter_image_embeds( image_embeds.append(single_image_embeds) else: + repeat_dims = [1] image_embeds = [] for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) - single_negative_image_embeds = single_negative_image_embeds.repeat(num_images_per_prompt, 1, 1) - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) else: - single_image_embeds = single_image_embeds.repeat(num_images_per_prompt, 1, 1) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) image_embeds.append(single_image_embeds) return image_embeds @@ -686,9 +693,9 @@ def check_inputs( raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) - elif ip_adapter_image_embeds[0].ndim != 3: + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( - f"`ip_adapter_image_embeds` has to be a list of 3D tensors but is {ip_adapter_image_embeds[0].ndim}D" + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents