diff --git a/examples/video-comprehension/README.md b/examples/video-comprehension/README.md
index da54f26740..68d0100508 100644
--- a/examples/video-comprehension/README.md
+++ b/examples/video-comprehension/README.md
@@ -31,3 +31,12 @@ python3 run_example.py \
 ```
 Models that have been validated:
 - [LanguageBind/Video-LLaVA-7B-hf ](https://huggingface.co/LanguageBind/Video-LLaVA-7B-hf)
+
+CogVideoX test:
+```bash
+python3 run_example.py \
+    --model_name_or_path THUDM/CogVideoX-2b \
+    --pipeline_type 'cogvideox' \
+    --output_dir 'cogvideo_out'
+```
+
diff --git a/examples/video-comprehension/run_example.py b/examples/video-comprehension/run_example.py
index 5868bea3e8..d22d6dcdec 100644
--- a/examples/video-comprehension/run_example.py
+++ b/examples/video-comprehension/run_example.py
@@ -31,6 +31,11 @@
     adapt_transformers_to_gaudi,
 )
 
+from optimum.habana.diffusers import GaudiCogVideoXPipeline
+from optimum.habana.transformers.gaudi_configuration import GaudiConfig
+from diffusers.utils.export_utils import export_to_video
+
+
 logging.basicConfig(
     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -52,6 +57,34 @@ def read_video_pyav(container, indices):
         frames.append(frame)
     return np.stack([x.to_ndarray(format="rgb24") for x in frames])
 
+def cogvideoX_generate(args):
+    gaudi_config_kwargs = {"use_fused_adam": True, "use_fused_clip_norm": True}
+    gaudi_config_kwargs["use_torch_autocast"] = True
+
+    gaudi_config = GaudiConfig(**gaudi_config_kwargs)
+    logger.info(f"Gaudi Config: {gaudi_config}")
+
+    kwargs = {
+        "use_habana": True,
+        "use_hpu_graphs": True,
+        "gaudi_config": gaudi_config,
+    }
+    kwargs["torch_dtype"] = torch.bfloat16
+    pipeline: GaudiCogVideoXPipeline = GaudiCogVideoXPipeline.from_pretrained(args.model_name_or_path, **kwargs)
+    pipeline.vae.enable_tiling()
+    pipeline.vae.enable_slicing()
+    video = pipeline(
+        prompt=args.prompt,
+        num_videos_per_prompt=1,
+        num_inference_steps=50,
+        num_frames=49,
+        guidance_scale=6,
+        generator=torch.Generator(device="cpu").manual_seed(42),
+    ).frames[0]
+    video_save_dir = Path(args.output_dir)
+    video_save_dir.mkdir(parents=True, exist_ok=True)
+    filename = video_save_dir / "cogvideoX_out.mp4"
+    export_to_video(video, str(filename.resolve()), fps=8)
 
 def main():
     parser = argparse.ArgumentParser()
@@ -63,6 +96,14 @@ def main():
         help="Path to pre-trained model",
     )
     parser.add_argument(
+        "--pipeline_type",
+        type=str,
+        default="sdp",
+        choices=["sdp", "cogvideox"],
+        help="Pipeline type: 'sdp' or 'cogvideox'.",
+    )
+    parser.add_argument(
+        "--video_path",
         default=None,
         type=str,
@@ -120,6 +161,10 @@ def main():
 
     args = parser.parse_args()
 
+    if args.pipeline_type == "cogvideox":
+        cogvideoX_generate(args)
+        return None
+
     os.environ.setdefault("EXPERIMENTAL_WEIGHT_SHARING", "FALSE")
 
     if args.video_path is None:
diff --git a/optimum/habana/diffusers/__init__.py b/optimum/habana/diffusers/__init__.py
index bd4ceb45ee..fcd35aed9c 100644
--- a/optimum/habana/diffusers/__init__.py
+++ b/optimum/habana/diffusers/__init__.py
@@ -1,4 +1,5 @@
 from .pipelines.auto_pipeline import AutoPipelineForInpainting, AutoPipelineForText2Image
+from .pipelines.cogvideox.pipeline_cogvideox_gaudi import GaudiCogVideoXPipeline
 from .pipelines.controlnet.pipeline_controlnet import GaudiStableDiffusionControlNetPipeline
 from .pipelines.controlnet.pipeline_stable_video_diffusion_controlnet import (
     GaudiStableVideoDiffusionControlNetPipeline,
diff --git a/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py
b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py new file mode 100644 index 0000000000..e77adc1c9a --- /dev/null +++ b/optimum/habana/diffusers/pipelines/cogvideox/pipeline_cogvideox_gaudi.py @@ -0,0 +1,846 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from diffusers import CogVideoXPipeline +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from diffusers.models.attention import Attention +from diffusers.models.autoencoders.autoencoder_kl_cogvideox import CogVideoXCausalConv3d +from diffusers.models.autoencoders.vae import DecoderOutput +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from diffusers.utils import ( + USE_PEFT_BACKEND, + BaseOutput, + is_torch_version, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import randn_tensor +from transformers import T5EncoderModel, T5Tokenizer + +from optimum.habana.diffusers.pipelines.pipeline_utils import GaudiDiffusionPipeline +from optimum.habana.transformers.gaudi_configuration import GaudiConfig + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +try: + from habana_frameworks.torch.hpex.kernels import FusedSDPA +except ImportError: + print("Not using HPU fused scaled dot-product attention kernel.") + FusedSDPA = None + + +# FusedScaledDotProductAttention +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA): + super().__init__() + self._hpu_kernel_fsdpa = fusedSDPA + + def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode): + return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode) + + +def apply_rotary_emb( + x: torch.Tensor, + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Adapted from: https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/models/embeddings.py#L697 + """ + cos_, sin_ = freqs_cis # [S, D] + + cos = cos_[None, None] + sin = sin_[None, None] + cos, sin = cos.to(x.device), sin.to(x.device) + + x = torch.ops.hpu.rotary_pos_embedding(x, sin, cos, None, 0, 1) + + return x + + +class CogVideoXAttnProcessorGaudi: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. 
+ """ + + def __init__(self): + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) + if not attn.is_cross_attention: + key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) + + hidden_states = self.fused_scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_casual=False, + scale=None, + softmax_mode="fast", + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + return hidden_states, encoder_hidden_states + + +def cogvideoXTransformerForwardGaudi( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, +): + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective.") + + batch_size, num_frames, channels, height, width = hidden_states.shape + + # 1. Time embedding + timesteps = timestep + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. 
+ t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # 2. Patch embedding + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + hidden_states = self.embedding_dropout(hidden_states) + + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + import habana_frameworks.torch.core as htcore + + # 3. Transformer blocks + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + htcore.mark_step() + + if not self.config.use_rotary_positional_embeddings: + # CogVideoX-2B + hidden_states = self.norm_final(hidden_states) + else: + # CogVideoX-5B + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, text_seq_length:] + + # 4. Final block + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # 5. Unpatchify + # Note: we use `-1` instead of `channels`: + # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels) + # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels) + p = self.config.patch_size + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + +def tiled_decode_gaudi(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + # Rough memory assessment: + # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers. + # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720]. + # - Assume fp16 (2 bytes per value). 
+ # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB + # + # Memory assessment when using tiling: + # - Assume everything as above but now HxW is 240x360 by tiling in half + # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB + + print("run gaudi pipelined tiled decode!") + batch_size, num_channels, num_frames, height, width = z.shape + + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height)) + overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width)) + blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height) + blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width) + row_limit_height = self.tile_sample_min_height - blend_extent_height + row_limit_width = self.tile_sample_min_width - blend_extent_width + frame_batch_size = self.num_latent_frames_batch_size + + import habana_frameworks.torch.core as htcore + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + num_batches = max(num_frames // frame_batch_size, 1) + conv_cache = None + time = [] + + for k in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) + end_frame = frame_batch_size * (k + 1) + remaining_frames + tile = z[ + :, + :, + start_frame:end_frame, + i : i + self.tile_latent_min_height, + j : j + self.tile_latent_min_width, + ].clone() + if self.post_quant_conv is not None: + tile = self.post_quant_conv(tile) + tile, conv_cache = self.decoder(tile, conv_cache=conv_cache) + time.append(tile.clone()) + htcore.mark_step() + + row.append(torch.cat(time, dim=2)) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + +def CogVideoXCausalConv3dforwardGaudi( + self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None +) -> torch.Tensor: + # print('run gaudi CogVideoXCausalConv3d forward!') + inputs = self.fake_context_parallel_forward(inputs, conv_cache) + # conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() + + padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad) + inputs_pad = F.pad(inputs, padding_2d, mode="constant", value=0) + + output = self.conv(inputs_pad) + if self.time_kernel_size > 1: + if conv_cache is not None and conv_cache.shape == inputs[:, :, -self.time_kernel_size + 1 :].shape: + conv_cache.copy_(inputs[:, :, -self.time_kernel_size + 1 :]) + else: + conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone() + return output, conv_cache + + +setattr(CogVideoXCausalConv3d, "forward", CogVideoXCausalConv3dforwardGaudi) +setattr(AutoencoderKLCogVideoX, "tiled_decode", tiled_decode_gaudi) + + +@dataclass +class GaudiTextToVideoSDPipelineOutput(BaseOutput): + r""" + Output 
class for CogVideo pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class GaudiCogVideoXPipeline(GaudiDiffusionPipeline, CogVideoXPipeline): + r""" + Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py#L84 + """ + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + use_habana: bool = False, + use_hpu_graphs: bool = False, + gaudi_config: Union[str, GaudiConfig] = None, + bf16_full_eval: bool = False, + ): + GaudiDiffusionPipeline.__init__( + self, + use_habana, + use_hpu_graphs, + gaudi_config, + bf16_full_eval, + ) + CogVideoXPipeline.__init__( + self, + tokenizer, + text_encoder, + vae, + transformer, + scheduler, + ) + self.to(self._device) + self.transformer.forward = cogvideoXTransformerForwardGaudi + for block in self.transformer.transformer_blocks: + block.attn1.set_processor(CogVideoXAttnProcessorGaudi()) + + from habana_frameworks.torch.hpu import wrap_in_hpu_graph + + self.vae.decoder = wrap_in_hpu_graph(self.vae.decoder) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + def enable_model_cpu_offload(self, *args, **kwargs): + if self.use_habana: + raise NotImplementedError("enable_model_cpu_offload() is not implemented for HPU") + else: + return super().enable_model_cpu_offload(*args, **kwargs) + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + # torch.randn is broken on HPU so running it on CPU + rand_device = "cpu" if device.type == "hpu" else device + rand_device = torch.device(rand_device) + latents = randn_tensor(shape, generator=generator, device=rand_device, dtype=dtype).to(device) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=self.gaudi_config.use_torch_autocast): + if num_frames > 49: + raise ValueError( + "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. 
Default height and width to unet + height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + # 5. Prepare latent variables + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 7. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + outputs = [] + import habana_frameworks.torch.core as htcore + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer_hpu( + latent_model_input=latent_model_input, + prompt_embeds=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + ) + + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) + / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[ + 0 + ] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + if not self.use_hpu_graphs: + htcore.mark_step() + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if not self.use_hpu_graphs: + htcore.mark_step() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return GaudiTextToVideoSDPipelineOutput(frames=video) + + @torch.no_grad() + def transformer_hpu(self, latent_model_input, prompt_embeds, timestep, image_rotary_emb): + if self.use_hpu_graphs: + return self.capture_replay(latent_model_input, prompt_embeds, timestep, image_rotary_emb) + else: + return self.transformer( + self.transformer, + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + @torch.no_grad() + def capture_replay(self, latent_model_input, prompt_embeds, timestep, image_rotary_emb): + inputs = [latent_model_input.clone(), prompt_embeds.clone(), timestep.clone(), image_rotary_emb, False] + h = self.ht.hpu.graphs.input_hash(inputs) + cached = self.cache.get(h) + + if cached is None: + # 
Capture the graph and cache it
+            with self.ht.hpu.stream(self.hpu_stream):
+                graph = self.ht.hpu.HPUGraph()
+                graph.capture_begin()
+                outputs = self.transformer(
+                    self.transformer,
+                    hidden_states=inputs[0],
+                    encoder_hidden_states=inputs[1],
+                    timestep=inputs[2],
+                    image_rotary_emb=inputs[3],
+                    return_dict=inputs[4],
+                )[0]
+                graph.capture_end()
+                graph_inputs = inputs
+                graph_outputs = outputs
+                self.cache[h] = self.ht.hpu.graphs.CachedParams(graph_inputs, graph_outputs, graph)
+            return outputs
+
+        # Replay the cached graph with updated inputs
+        self.ht.hpu.graphs.copy_to(cached.graph_inputs, inputs)
+        cached.graph.replay()
+        self.ht.core.hpu.default_stream().synchronize()
+
+        return cached.graph_outputs
diff --git a/tests/test_diffusers.py b/tests/test_diffusers.py
index fe3a9aba11..c076828cfa 100644
--- a/tests/test_diffusers.py
+++ b/tests/test_diffusers.py
@@ -39,8 +39,11 @@
 import torch
 from diffusers import (
     AutoencoderKL,
+    AutoencoderKLCogVideoX,
     AutoencoderKLTemporalDecoder,
     AutoencoderTiny,
+    CogVideoXDDIMScheduler,
+    CogVideoXTransformer3DModel,
     ControlNetModel,
     DiffusionPipeline,
     EulerAncestralDiscreteScheduler,
@@ -87,12 +90,15 @@
     DPTConfig,
     DPTFeatureExtractor,
     DPTForDepthEstimation,
+    T5Config,
     T5EncoderModel,
+    T5Tokenizer,
 )
 from transformers.testing_utils import parse_flag_from_env, slow
 
 from optimum.habana import GaudiConfig
 from optimum.habana.diffusers import (
+    GaudiCogVideoXPipeline,
     GaudiDDIMScheduler,
     GaudiDDPMPipeline,
     GaudiDiffusionPipeline,
@@ -3795,6 +3801,141 @@ def test_stable_diffusion_xl_img2img_euler(self):
 
         self.assertLess(np.abs(image_slice.flatten() - expected_slice).max(), 1e-2)
 
+
+class GaudiCogVideoXPipelineTester(TestCase):
+    """
+    Tests the GaudiCogVideoXPipeline for Gaudi.
+    Adapted from https://github.com/huggingface/diffusers/blob/v0.24.0-release/tests/pipelines/text_to_video_synthesis/test_text_to_video.py
+    """
+
+    def get_dummy_components(self):
+        tokenizer = T5Tokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+        set_seed(0)
+        text_encoder_cfg = T5Config(
+            vocab_size=32128,
+            d_kv=64,
+            d_ff=10240,
+            num_layers=8,
+            num_decoder_layers=8,
+            relative_attention_num_buckets=32,
+            relative_attention_max_distance=128,
+            initializer_factor=1.0,
+            feed_forward_proj="gated-gelu",
+            is_encoder_decoder=True,
+            pad_token_id=0,
+            eos_token_id=1,
+            torch_dtype=torch.bfloat16,
+            d_model=4096,
+        )
+        text_encoder = T5EncoderModel(text_encoder_cfg).bfloat16()
+
+        set_seed(0)
+        transformer = CogVideoXTransformer3DModel(
+            num_attention_heads=30,
+            attention_head_dim=64,
+            in_channels=16,
+            out_channels=16,
+            flip_sin_to_cos=True,
+            freq_shift=0,
+            time_embed_dim=512,
+            text_embed_dim=4096,
+            num_layers=8,
+            dropout=0.0,
+            attention_bias=True,
+            sample_width=90,
+            sample_height=60,
+            sample_frames=49,
+            patch_size=2,
+            temporal_compression_ratio=4,
+            max_text_seq_length=226,
+            activation_fn="gelu-approximate",
+            timestep_activation_fn="silu",
+            norm_elementwise_affine=True,
+            norm_eps=1e-5,
+            spatial_interpolation_scale=1.875,
+            temporal_interpolation_scale=1.0,
+        ).bfloat16()
+
+        scheduler = CogVideoXDDIMScheduler(
+            num_train_timesteps=1000,
+            beta_start=0.00085,
+            beta_end=0.0120,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=True,
+            steps_offset=0,
+            prediction_type="v_prediction",
+            clip_sample_range=1.0,
+            sample_max_value=1.0,
+            timestep_spacing="trailing",
+            rescale_betas_zero_snr=True,
+            snr_shift_scale=1.0,
+        )
+
+        set_seed(0)
+        vae = AutoencoderKLCogVideoX(
+            in_channels=3,
+            out_channels=3,
+            down_block_types=[
+                "CogVideoXDownBlock3D",
+                "CogVideoXDownBlock3D",
+                "CogVideoXDownBlock3D",
+                "CogVideoXDownBlock3D",
+            ],
+            block_out_channels=[128, 256, 256, 512],
+            latent_channels=16,
+            layers_per_block=1,
+            act_fn="silu",
+            norm_eps=1e-6,
+            norm_num_groups=32,
+            temporal_compression_ratio=4,
+            sample_height=480,
+            sample_width=720,
+            scaling_factor=1.15258426,
+        ).bfloat16()
+
+        vae.enable_slicing()
+        vae.enable_tiling()
+
+        components = {
+            "tokenizer": tokenizer,
+            "text_encoder": text_encoder,
+            "transformer": transformer,
+            "scheduler": scheduler,
+            "vae": vae,
+        }
+
+        return components
+
+    def get_dummy_inputs(self):
+        prompts = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
+        return prompts
+
+    def test_cogvideoX_default_case(self):
+        gaudi_config_kwargs = {"use_fused_adam": True, "use_fused_clip_norm": True}
+        gaudi_config_kwargs["use_torch_autocast"] = True
+        gaudi_config = GaudiConfig(**gaudi_config_kwargs)
+
+        components = self.get_dummy_components()
+        components["use_habana"] = True
+        components["use_hpu_graphs"] = True
+        components["gaudi_config"] = gaudi_config
+
+        prompts = self.get_dummy_inputs()
+        cogVideoX_pipe = GaudiCogVideoXPipeline(**components)
+        video = cogVideoX_pipe(
+            prompt=prompts,
+            num_videos_per_prompt=1,
+            num_inference_steps=5,
+            num_frames=49,
+            guidance_scale=6,
+            generator=torch.Generator(device="cpu").manual_seed(42),
+        ).frames[0]
+
+        self.assertIsNotNone(video)
+        self.assertEqual(49, len(video))
+
+
 class GaudiTextToVideoSDPipelineTester(TestCase):
     """
     Tests the TextToVideoSDPipeline for Gaudi.
@@ -3802,6 +3943,7 @@ class GaudiTextToVideoSDPipelineTester(TestCase):
     """
 
     def get_dummy_components(self):
+        set_seed(0)
         unet = UNet3DConditionModel(
             block_out_channels=(4, 8),
@@ -4224,7 +4366,6 @@ def test_pipeline_call_signature(self):
         for k, v in parameters.items():
             if v.default != inspect._empty:
                 optional_parameters.add(k)
-
         parameters = set(parameters.keys())
         parameters.remove("self")
         parameters.discard("kwargs")  # kwargs can be added if arguments of pipeline call function are deprecated
@@ -4294,7 +4435,6 @@ def _test_inference_batch_consistent(
             if "batch_size" in inputs:
                 batched_input["batch_size"] = batch_size
-
             batched_inputs.append(batched_input)
 
         logger.setLevel(level=diffusers.logging.WARNING)
         for batch_size, batched_input in zip(batch_sizes, batched_inputs):
@@ -5035,7 +5175,6 @@ def test_stable_diffusion_inpaint_no_throughput_regression(self):
         mask_image = load_image(
             "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png"
         )
-
         prompts = [
             "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k",
         ]
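
For reference, the new `GaudiCogVideoXPipeline` can also be driven directly, outside of `run_example.py`. The sketch below is a minimal, non-authoritative usage example that mirrors the `cogvideoX_generate` helper added above (same model id, frame count, step count, and guidance scale); it assumes a Gaudi/HPU machine with `diffusers` and an `optimum-habana` build that includes this change, and the prompt string is only a placeholder.

```python
# Minimal usage sketch for the GaudiCogVideoXPipeline introduced in this diff.
# Assumes an HPU device; model id, frame/step counts, and guidance scale mirror
# the example script above.
import torch
from diffusers.utils import export_to_video

from optimum.habana.diffusers import GaudiCogVideoXPipeline
from optimum.habana.transformers.gaudi_configuration import GaudiConfig

gaudi_config = GaudiConfig(use_fused_adam=True, use_fused_clip_norm=True, use_torch_autocast=True)

pipeline = GaudiCogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-2b",
    use_habana=True,
    use_hpu_graphs=True,
    gaudi_config=gaudi_config,
    torch_dtype=torch.bfloat16,
)
# Tiled and sliced VAE decoding keep peak memory bounded, as in cogvideoX_generate.
pipeline.vae.enable_tiling()
pipeline.vae.enable_slicing()

video = pipeline(
    prompt="A panda strumming a tiny guitar in a bamboo forest.",  # placeholder prompt
    num_videos_per_prompt=1,
    num_inference_steps=50,
    num_frames=49,
    guidance_scale=6,
    generator=torch.Generator(device="cpu").manual_seed(42),
).frames[0]
export_to_video(video, "cogvideoX_out.mp4", fps=8)
```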