diff --git a/examples/community/pipeline_animatediff_controlnet.py b/examples/community/pipeline_animatediff_controlnet.py
index 56dd856e1811..779477301157 100644
--- a/examples/community/pipeline_animatediff_controlnet.py
+++ b/examples/community/pipeline_animatediff_controlnet.py
@@ -13,7 +13,6 @@
 # limitations under the License.

 import inspect
-from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
@@ -27,6 +26,7 @@
 from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel, UNetMotionModel
 from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.models.unets.unet_motion_model import MotionAdapter
+from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput
 from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from diffusers.schedulers import (
@@ -37,7 +37,7 @@
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
@@ -91,10 +91,8 @@
 """


+# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
 def tensor2vid(video: torch.Tensor, processor, output_type="np"):
-    # Based on:
-    # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
-
     batch_size, channels, num_frames, height, width = video.shape
     outputs = []
     for batch_idx in range(batch_size):
@@ -103,12 +101,16 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):

         outputs.append(batch_output)

-    return outputs
+    if output_type == "np":
+        outputs = np.stack(outputs)
+
+    elif output_type == "pt":
+        outputs = torch.stack(outputs)
+    elif not output_type == "pil":
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

-@dataclass
-class AnimateDiffControlNetPipelineOutput(BaseOutput):
-    frames: Union[torch.Tensor, np.ndarray]
+    return outputs


 class AnimateDiffControlNetPipeline(
@@ -843,8 +845,8 @@ def __call__(

         Examples:

         Returns:
-            [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
+            [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
                 returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """
@@ -1020,7 +1022,7 @@ def __call__(
             ]
             controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)

-        # Denoising loop
+        # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -1096,21 +1098,17 @@ def __call__(
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)

+        # 9. Post processing
         if output_type == "latent":
-            return AnimateDiffControlNetPipelineOutput(frames=latents)
-
-        # Post-processing
-        video_tensor = self.decode_latents(latents)
-
-        if output_type == "pt":
-            video = video_tensor
+            video = latents
         else:
+            video_tensor = self.decode_latents(latents)
             video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

-        # Offload all models
+        # 10. Offload all models
         self.maybe_free_model_hooks()

         if not return_dict:
             return (video,)

-        return AnimateDiffControlNetPipelineOutput(frames=video)
+        return AnimateDiffPipelineOutput(frames=video)
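For reference, the standardized `tensor2vid` now stacks `"np"` and `"pt"` outputs over the batch dimension and raises on unknown output types instead of silently returning a list. A minimal sketch of that behavior, assuming frames already in `[0, 1]` and a processor built with `do_normalize=False` so `postprocess` leaves the values untouched (both choices are illustrative, not part of the diff):

```python
# Minimal sketch of the standardized tensor2vid behavior; the shapes and the
# do_normalize=False processor are illustrative assumptions.
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.animatediff.pipeline_animatediff import tensor2vid

processor = VaeImageProcessor(do_normalize=False)  # assume frames already in [0, 1]
video = torch.rand(2, 3, 8, 64, 64)  # (batch, channels, num_frames, height, width)

frames_np = tensor2vid(video, processor, output_type="np")    # np.ndarray stacked over batch
frames_pt = tensor2vid(video, processor, output_type="pt")    # torch.Tensor stacked over batch
frames_pil = tensor2vid(video, processor, output_type="pil")  # list (per batch) of lists of PIL images
# tensor2vid(video, processor, output_type="gif")  # now raises ValueError
```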
diff --git a/examples/community/pipeline_animatediff_img2video.py b/examples/community/pipeline_animatediff_img2video.py
index fecb7211c7b9..d9209122262f 100644
--- a/examples/community/pipeline_animatediff_img2video.py
+++ b/examples/community/pipeline_animatediff_img2video.py
@@ -158,10 +158,8 @@ def slerp(
     return v2


+# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
 def tensor2vid(video: torch.Tensor, processor, output_type="np"):
-    # Based on:
-    # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
-
     batch_size, channels, num_frames, height, width = video.shape
     outputs = []
     for batch_idx in range(batch_size):
@@ -170,6 +168,15 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):

         outputs.append(batch_output)

+    if output_type == "np":
+        outputs = np.stack(outputs)
+
+    elif output_type == "pt":
+        outputs = torch.stack(outputs)
+
+    elif not output_type == "pil":
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
+
     return outputs

@@ -826,8 +833,8 @@ def __call__(

         Examples:

         Returns:
-            [`AnimateDiffPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is
+            [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
                 returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """
         # 0. Default height and width to unet
@@ -958,11 +965,10 @@ def __call__(
             return AnimateDiffPipelineOutput(frames=latents)

         # 10. Post-processing
-        video_tensor = self.decode_latents(latents)
-
-        if output_type == "pt":
-            video = video_tensor
+        if output_type == "latent":
+            video = latents
         else:
+            video_tensor = self.decode_latents(latents)
             video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

         # 11. Offload all models
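Both community pipelines now return the shared `AnimateDiffPipelineOutput` rather than a pipeline-local dataclass, so downstream code can consume a single type. A short sketch, reusing `frames_np` from the previous example:

```python
# Sketch: the shared output dataclass these pipelines now return.
from diffusers.pipelines.animatediff.pipeline_output import AnimateDiffPipelineOutput

out = AnimateDiffPipelineOutput(frames=frames_np)
first_video = out.frames[0]  # frames for the first batch entry
```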
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
index ce2f6585c601..cd7f0a283b63 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -81,7 +81,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
         outputs = torch.stack(outputs)

     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

     return outputs

@@ -668,8 +668,8 @@ def __call__(

         Examples:

         Returns:
-            [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
+            [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
                 returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """
@@ -790,6 +790,8 @@ def __call__(
         self._num_timesteps = len(timesteps)
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+
+        # 8. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
@@ -829,13 +831,14 @@ def __call__(
                 if callback is not None and i % callback_steps == 0:
                     callback(i, t, latents)

+        # 9. Post processing
         if output_type == "latent":
-            return AnimateDiffPipelineOutput(frames=latents)
-
-        video_tensor = self.decode_latents(latents)
-        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
+            video = latents
+        else:
+            video_tensor = self.decode_latents(latents)
+            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

-        # 9. Offload all models
+        # 10. Offload all models
         self.maybe_free_model_hooks()

         if not return_dict:
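The behavioral point of the reshuffled tail: previously `output_type="latent"` returned the output object before `maybe_free_model_hooks()` ran and ignored `return_dict=False`; now latents flow through the same exit path as decoded frames. A hedged usage sketch (the model IDs are illustrative assumptions, not mandated by this diff):

```python
# Hedged usage sketch; model IDs are illustrative assumptions.
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# return_dict=False is now honored even for latent output: a plain tuple comes back.
(latents,) = pipe("a rocket launching", num_frames=16, output_type="latent", return_dict=False)

# With return_dict=True the latents arrive wrapped in AnimateDiffPipelineOutput.
result = pipe("a rocket launching", num_frames=16, output_type="latent")
assert torch.is_tensor(result.frames)
```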
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
index bcfcd3a24b5d..cb6b71351faf 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
@@ -100,7 +100,7 @@ def tensor2vid(video: torch.Tensor, processor, output_type="np"):
         outputs = torch.stack(outputs)

     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

     return outputs

@@ -828,8 +828,8 @@ def __call__(

         Examples:

         Returns:
-            [`AnimateDiffPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`AnimateDiffPipelineOutput`] is
+            [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.animatediff.pipeline_output.AnimateDiffPipelineOutput`] is
                 returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """
@@ -942,6 +942,7 @@ def __call__(
         self._num_timesteps = len(timesteps)
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

+        # 8. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -980,15 +981,11 @@ def __call__(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()

-        if output_type == "latent":
-            return AnimateDiffPipelineOutput(frames=latents)
-
         # 9. Post-processing
-        video_tensor = self.decode_latents(latents)
-
-        if output_type == "pt":
-            video = video_tensor
+        if output_type == "latent":
+            video = latents
         else:
+            video_tensor = self.decode_latents(latents)
             video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

         # 10. Offload all models
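Since every pipeline now funnels decoding through `tensor2vid`, the default `"pil"` path composes directly with the export utilities. A small sketch, assuming `pipe` is the AnimateDiff pipeline built in the previous example:

```python
# Sketch: exporting the default "pil" output; `pipe` is assumed from the
# previous example.
from diffusers.utils import export_to_gif

output = pipe("a rocket launching", num_frames=16)  # output_type defaults to "pil"
frames = output.frames[0]  # list of PIL.Image.Image frames for the first video
export_to_gif(frames, "rocket.gif")
```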
diff --git a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
index 2df21533962c..cb6f3e300904 100644
--- a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
+++ b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
@@ -83,7 +83,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
         outputs = torch.stack(outputs)

     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

     return outputs

@@ -726,13 +726,14 @@ def __call__(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()

+        # 8. Post processing
         if output_type == "latent":
-            return I2VGenXLPipelineOutput(frames=latents)
-
-        video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size)
-        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
+            video = latents
+        else:
+            video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size)
+            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

-        # Offload all models
+        # 9. Offload all models
         self.maybe_free_model_hooks()

         if not return_dict:
diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py
index bd3e2891f0d6..507088991a5e 100644
--- a/src/diffusers/pipelines/pia/pipeline_pia.py
+++ b/src/diffusers/pipelines/pia/pipeline_pia.py
@@ -107,7 +107,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
         outputs = torch.stack(outputs)

     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

     return outputs

@@ -860,8 +860,8 @@ def __call__(

         Examples:

         Returns:
-            [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
+            [`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] is
                 returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """
         # 0. Default height and width to unet
@@ -1018,13 +1018,14 @@ def __call__(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()

+        # 9. Post processing
         if output_type == "latent":
-            return PIAPipelineOutput(frames=latents)
-
-        video_tensor = self.decode_latents(latents)
-        video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
+            video = latents
+        else:
+            video_tensor = self.decode_latents(latents)
+            video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)

-        # 9. Offload all models
+        # 10. Offload all models
         self.maybe_free_model_hooks()

         if not return_dict:
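With `output_type="latent"` now returning raw latents through the standard output object, callers can defer decoding. A hedged sketch of the manual round trip; it mirrors what the pipelines' `decode_latents` does internally (I2VGenXL additionally chunks the decode via `decode_chunk_size`, omitted here):

```python
# Hedged sketch: manually decoding latents obtained with output_type="latent".
# Mirrors the pipelines' decode_latents logic; `pipe` is assumed from above.
latents = pipe("a rocket launching", num_frames=16, output_type="latent").frames

latents = latents / pipe.vae.config.scaling_factor
b, c, f, h, w = latents.shape
flat = latents.permute(0, 2, 1, 3, 4).reshape(b * f, c, h, w)
image = pipe.vae.decode(flat).sample  # (b * f, 3, H, W), roughly in [-1, 1]
video = image.reshape(b, f, *image.shape[1:]).permute(0, 2, 1, 3, 4)
```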
diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
index f53ebbafee2e..5cc4024a4acc 100644
--- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
+++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
@@ -57,7 +57,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
         outputs = torch.stack(outputs)

     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

     return outputs

diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
index 0ed0765703f2..49d7cd54656f 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
@@ -76,7 +76,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
         outputs = torch.stack(outputs)

     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

     return outputs

@@ -646,13 +646,14 @@ def __call__(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)

+        # 8. Post processing
         if output_type == "latent":
-            return TextToVideoSDPipelineOutput(frames=latents)
-
-        video_tensor = self.decode_latents(latents)
-        video = tensor2vid(video_tensor, self.image_processor, output_type)
+            video = latents
+        else:
+            video_tensor = self.decode_latents(latents)
+            video = tensor2vid(video_tensor, self.image_processor, output_type)

-        # Offload all models
+        # 9. Offload all models
         self.maybe_free_model_hooks()

         if not return_dict:
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
index 40c486316e13..4b088a682498 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
@@ -111,7 +111,7 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type:
         outputs = torch.stack(outputs)

     elif not output_type == "pil":
-        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil]")
+        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

     return outputs

@@ -694,13 +694,13 @@ def __call__(
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

-        # 5. Prepare latent variables
+        # 6. Prepare latent variables
         latents = self.prepare_latents(video, latent_timestep, batch_size, prompt_embeds.dtype, device, generator)

-        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        # 7. Denoising loop
+        # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -740,20 +740,18 @@ def __call__(
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)

-        if output_type == "latent":
-            return TextToVideoSDPipelineOutput(frames=latents)
-
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")

+        # 9. Post processing
         if output_type == "latent":
-            return TextToVideoSDPipelineOutput(frames=latents)
-
-        video_tensor = self.decode_latents(latents)
-        video = tensor2vid(video_tensor, self.image_processor, output_type)
+            video = latents
+        else:
+            video_tensor = self.decode_latents(latents)
+            video = tensor2vid(video_tensor, self.image_processor, output_type)

-        # Offload all models
+        # 10. Offload all models
         self.maybe_free_model_hooks()

         if not return_dict:
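As a final usage note, the text-to-video pipelines keep `"np"` as their default output type, which now comes back stacked over the batch dimension rather than as a bare list. A hedged sketch with an illustrative model ID:

```python
# Hedged sketch; the model ID is illustrative.
import torch
from diffusers import TextToVideoSDPipeline
from diffusers.utils import export_to_video

pipe = TextToVideoSDPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16
).to("cuda")

frames = pipe("a corgi running on the beach").frames  # np.ndarray, (batch, frames, H, W, 3)
export_to_video(list(frames[0]), "corgi.mp4")
```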