diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml
index 76afb4b37..1275b13cd 100644
--- a/.github/workflows/examples.yml
+++ b/.github/workflows/examples.yml
@@ -36,6 +36,15 @@ concurrency:
   cancel-in-progress: true
 
 jobs:
+  pre-commit:
+    runs-on: [ubuntu-latest]
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python
+        uses: actions/setup-python@v3
+        with:
+          python-version: '3.10'
+      - uses: pre-commit/action@v3.0.1
   upload_src:
     if: github.repository == 'siliconflow/onediff'
     runs-on: [ubuntu-latest]
diff --git a/onediff_comfy_nodes/modules/booster_cache.py b/onediff_comfy_nodes/modules/booster_cache.py
index b6c4e5347..bcc766063 100644
--- a/onediff_comfy_nodes/modules/booster_cache.py
+++ b/onediff_comfy_nodes/modules/booster_cache.py
@@ -7,6 +7,7 @@
 from onediff.torch_utils.module_operations import get_sub_module
 from onediff.utils.import_utils import is_oneflow_available
 
+
 @singledispatch
 def switch_to_cached_model(new_model, cached_model):
     raise NotImplementedError(type(new_model))
diff --git a/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/CrossAttentionPatch.py b/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/CrossAttentionPatch.py
index 749029178..cf9cf73e7 100644
--- a/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/CrossAttentionPatch.py
+++ b/onediff_comfy_nodes/modules/oneflow/infer_compiler_registry/register_comfy/CrossAttentionPatch.py
@@ -26,7 +26,7 @@ def tensor_to_size(source, dest_size):
     return source
 
 
-def get_weight_subidxs(weight,ad_params,sub_idxs):
+def get_weight_subidxs(weight, ad_params, sub_idxs):
     return weight[ad_params[sub_idxs]]
 
 
@@ -167,7 +167,7 @@ def ipadapter_attention(
     if ad_params is not None and ad_params["sub_idxs"] is not None:
         if isinstance(weight, torch.Tensor) and weight.dim() != 0:
             weight = tensor_to_size(weight, ad_params["full_length"])
-            weight = get_weight_subidxs(weight,ad_params,"sub_idxs")
+            weight = get_weight_subidxs(weight, ad_params, "sub_idxs")
             # if torch.all(weight == 0):
             # return 0
             weight = weight.repeat(
@@ -178,8 +178,8 @@
 
         # if image length matches or exceeds full_length get sub_idx images
         if cond.shape[0] >= ad_params["full_length"]:
-            cond = get_weight_subidxs(cond,ad_params,"sub_idxs")
-            uncond = get_weight_subidxs(uncond,ad_params,"sub_idxs")
+            cond = get_weight_subidxs(cond, ad_params, "sub_idxs")
+            uncond = get_weight_subidxs(uncond, ad_params, "sub_idxs")
         # otherwise get sub_idxs images
         else:
             cond = tensor_to_size(cond, ad_params["full_length"])
diff --git a/onediff_diffusers_extensions/examples/kolors/README.md b/onediff_diffusers_extensions/examples/kolors/README.md
index d179afb43..1b6422aa9 100644
--- a/onediff_diffusers_extensions/examples/kolors/README.md
+++ b/onediff_diffusers_extensions/examples/kolors/README.md
@@ -118,4 +118,3 @@ python3 onediff_diffusers_extensions/examples/kolors/text_to_image_kolors.py \
 
 The quality report for accelerating the kolors model with onediff is located at:
 https://github.com/siliconflow/odeval/tree/main/models/kolors
-
diff --git a/onediff_diffusers_extensions/examples/kolors/text_to_image_kolors.py b/onediff_diffusers_extensions/examples/kolors/text_to_image_kolors.py
index 051e67443..fb43f93a9 100644
--- a/onediff_diffusers_extensions/examples/kolors/text_to_image_kolors.py
+++ b/onediff_diffusers_extensions/examples/kolors/text_to_image_kolors.py
@@ -2,10 +2,11 @@
 import json
 import time
 
+import torch
+
 from diffusers import DPMSolverMultistepScheduler, KolorsPipeline
-from onediffx import compile_pipe, quantize_pipe, load_pipe, save_pipe
 from onediff.infer_compiler import oneflow_compile
-import torch
+from onediffx import compile_pipe, load_pipe, quantize_pipe, save_pipe
 
 
 def parse_args():
@@ -103,7 +104,7 @@ def __init__(
         elif compiler == "oneflow":
            print("oneflow backend compile...")
            # self.pipe.unet = self.oneflow_compile(self.pipe.unet)
-           self.pipe = compile_pipe(self.pipe, ignores=['text_encoder', 'vae'])
+           self.pipe = compile_pipe(self.pipe, ignores=["text_encoder", "vae"])
 
     def warmup(self, gen_args, warmup_iterations):
         warmup_args = gen_args.copy()
diff --git a/src/infer_compiler_registry/register_diffusers/transformer_2d/v_0_28.py b/src/infer_compiler_registry/register_diffusers/transformer_2d/v_0_28.py
index dd8c82d88..2b74b98fe 100644
--- a/src/infer_compiler_registry/register_diffusers/transformer_2d/v_0_28.py
+++ b/src/infer_compiler_registry/register_diffusers/transformer_2d/v_0_28.py
@@ -1,6 +1,8 @@
+from typing import Any, Dict, Optional
+
 import torch
-from typing import Optional, Dict, Any
 from onediff.infer_compiler.backends.oneflow.transform import transform_mgr
+
 transformed_diffusers = transform_mgr.transform_package("diffusers")
 ConfigMixin = transformed_diffusers.configuration_utils.ConfigMixin
 register_to_config = transformed_diffusers.configuration_utils.register_to_config
@@ -20,11 +22,14 @@
 LoRACompatibleLinear = transformed_diffusers.models.lora.LoRACompatibleLinear
 ModelMixin = transformed_diffusers.models.modeling_utils.ModelMixin
 AdaLayerNormSingle = transformed_diffusers.models.normalization.AdaLayerNormSingle
-Transformer2DModelOutput = transformed_diffusers.models.transformers.transformer_2d.Transformer2DModelOutput
+Transformer2DModelOutput = (
+    transformed_diffusers.models.transformers.transformer_2d.Transformer2DModelOutput
+)
 proxy_Transformer2DModel = (
     transformed_diffusers.models.transformers.transformer_2d.Transformer2DModel
 )
 
+
 class Transformer2DModel(proxy_Transformer2DModel):
     def forward(
         self,
@@ -79,7 +84,9 @@ def forward(
         """
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+                logger.warning(
+                    "Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored."
+                )
         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
         # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
         # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -100,7 +107,9 @@ def forward(
 
         # convert encoder_attention_mask to a bias the same way we do for attention_mask
         if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
-            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+            encoder_attention_mask = (
+                1 - encoder_attention_mask.to(hidden_states.dtype)
+            ) * -10000.0
             encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
 
         hidden_states_in = hidden_states
@@ -113,8 +122,16 @@ def forward(
         elif self.is_input_vectorized:
             hidden_states = self.latent_image_embedding(hidden_states)
         elif self.is_input_patches:
-            height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
-            hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
+            height, width = (
+                hidden_states.shape[-2] // self.patch_size,
+                hidden_states.shape[-1] // self.patch_size,
+            )
+            (
+                hidden_states,
+                encoder_hidden_states,
+                timestep,
+                embedded_timestep,
+            ) = self._operate_on_patched_inputs(
                 hidden_states, encoder_hidden_states, timestep, added_cond_kwargs
             )
 
@@ -131,7 +148,9 @@ def custom_forward(*inputs):
 
                     return custom_forward
 
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                ckpt_kwargs: Dict[str, Any] = (
+                    {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                )
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(block),
                     hidden_states,
@@ -203,7 +222,9 @@ def custom_forward(*inputs):
 
         return Transformer2DModelOutput(sample=output)
 
-    def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim):
+    def _get_output_for_continuous_inputs(
+        self, hidden_states, residual, batch_size, height, width, inner_dim
+    ):
         # # hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
         # hidden_states = (
         # hidden_states.permute(0, 2, 1)
@@ -215,13 +236,17 @@ def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size,
         return
         if not self.use_linear_projection:
             hidden_states = (
-                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+                hidden_states.reshape(batch_size, height, width, inner_dim)
+                .permute(0, 3, 1, 2)
+                .contiguous()
             )
             hidden_states = self.proj_out(hidden_states)
         else:
             hidden_states = self.proj_out(hidden_states)
             hidden_states = (
-                hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+                hidden_states.reshape(batch_size, height, width, inner_dim)
+                .permute(0, 3, 1, 2)
+                .contiguous()
             )
 
         output = hidden_states + residual
@@ -238,7 +263,13 @@ def _get_output_for_vectorized_inputs(self, hidden_states):
         return output
 
     def _get_output_for_patched_inputs(
-        self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None
+        self,
+        hidden_states,
+        timestep,
+        class_labels,
+        embedded_timestep,
+        height=None,
+        width=None,
     ):
         raise
         # import ipdb; ipdb.set_trace()
@@ -247,10 +278,14 @@ def _get_output_for_patched_inputs(
                 timestep, class_labels, hidden_dtype=hidden_states.dtype
             )
             shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
-            hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+            hidden_states = (
+                self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+            )
             hidden_states = self.proj_out_2(hidden_states)
         elif self.config.norm_type == "ada_norm_single":
-            shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+            shift, scale = (
+                self.scale_shift_table[None] + embedded_timestep[:, None]
+            ).chunk(2, dim=1)
             hidden_states = self.norm_out(hidden_states)
             # Modulation
             hidden_states = hidden_states * (1 + scale) + shift
@@ -261,11 +296,23 @@ def _get_output_for_patched_inputs(
         if self.adaln_single is None:
             height = width = int(hidden_states.shape[1] ** 0.5)
         hidden_states = hidden_states.reshape(
-            shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+            shape=(
+                -1,
+                height,
+                width,
+                self.patch_size,
+                self.patch_size,
+                self.out_channels,
+            )
         )
         hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
         output = hidden_states.reshape(
-            shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+            shape=(
+                -1,
+                self.out_channels,
+                height * self.patch_size,
+                width * self.patch_size,
+            )
         )
         return output
 
@@ -284,4 +331,4 @@ def _operate_on_continuous_inputs(self, hidden_states):
             hidden_states = hidden_states.permute(0, 2, 3, 1).flatten(1, 2)
             hidden_states = self.proj_in(hidden_states)
 
-        return hidden_states, inner_dim
\ No newline at end of file
+        return hidden_states, inner_dim
diff --git a/src/infer_compiler_registry/register_diffusers/transformer_2d_oflow.py b/src/infer_compiler_registry/register_diffusers/transformer_2d_oflow.py
index 213eebd59..f30893d4d 100644
--- a/src/infer_compiler_registry/register_diffusers/transformer_2d_oflow.py
+++ b/src/infer_compiler_registry/register_diffusers/transformer_2d_oflow.py
@@ -921,7 +921,7 @@ def forward(
     )
     proxy_Transformer2DModel = (
         transformed_diffusers.models.transformer_2d.Transformer2DModel
-        )
+    )
     class Transformer2DModel(proxy_Transformer2DModel):
         def forward(
             self,
@@ -1213,5 +1213,6 @@ def custom_forward(*inputs):
                 return (output,)
 
             return Transformer2DModelOutput(sample=output)
+
 else:
     from .transformer_2d.v_0_28 import Transformer2DModel, Transformer2DModelOutput
diff --git a/src/infer_compiler_registry/register_diffusers/unet_2d_blocks_oflow.py b/src/infer_compiler_registry/register_diffusers/unet_2d_blocks_oflow.py
index 6009adfb6..f4ac7340c 100644
--- a/src/infer_compiler_registry/register_diffusers/unet_2d_blocks_oflow.py
+++ b/src/infer_compiler_registry/register_diffusers/unet_2d_blocks_oflow.py
@@ -43,9 +43,7 @@ def forward(
                     else:
                         hidden_states = upsampler(hidden_states)
 
-    class CrossAttnUpBlock2D(
-        diffusers_unet_2d_blocks.CrossAttnUpBlock2D
-    ):
+    class CrossAttnUpBlock2D(diffusers_unet_2d_blocks.CrossAttnUpBlock2D):
         def forward(
             self,
             hidden_states: torch.FloatTensor,
@@ -185,9 +183,7 @@ def forward(
 
             return hidden_states
 
-    class CrossAttnUpBlock2D(
-        diffusers_unet_2d_blocks.CrossAttnUpBlock2D
-    ):
+    class CrossAttnUpBlock2D(diffusers_unet_2d_blocks.CrossAttnUpBlock2D):
         def forward(
             self,
             hidden_states: torch.FloatTensor,
diff --git a/src/infer_compiler_registry/register_diffusers/unet_2d_condition_oflow.py b/src/infer_compiler_registry/register_diffusers/unet_2d_condition_oflow.py
index 49f59ce1f..cb30bf835 100644
--- a/src/infer_compiler_registry/register_diffusers/unet_2d_condition_oflow.py
+++ b/src/infer_compiler_registry/register_diffusers/unet_2d_condition_oflow.py
@@ -14,12 +14,16 @@
     UNet2DConditionOutput = (
         transformed_diffusers.models.unet_2d_condition.UNet2DConditionOutput
     )
-    proxy_UNet2DConditionModel = transformed_diffusers.models.unet_2d_condition.UNet2DConditionModel
+    proxy_UNet2DConditionModel = (
+        transformed_diffusers.models.unet_2d_condition.UNet2DConditionModel
+    )
 else:
     UNet2DConditionOutput = (
         transformed_diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput
     )
-    proxy_UNet2DConditionModel = transformed_diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
+    proxy_UNet2DConditionModel = (
+        transformed_diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
+    )
 
 try:
     USE_PEFT_BACKEND = transformed_diffusers.utils.USE_PEFT_BACKEND
diff --git a/src/onediff/infer_compiler/backends/oneflow/args_tree_util.py b/src/onediff/infer_compiler/backends/oneflow/args_tree_util.py
index fdac2a081..8fbe9b7d9 100644
--- a/src/onediff/infer_compiler/backends/oneflow/args_tree_util.py
+++ b/src/onediff/infer_compiler/backends/oneflow/args_tree_util.py
@@ -1,6 +1,7 @@
 import torch
 import oneflow as flow  # usort: skip
 import functools
+
 from oneflow.framework.args_tree import ArgsTree
 
 from onediff.utils import logger
diff --git a/src/onediff/infer_compiler/backends/oneflow/import_tools/patch_for_compiler.py b/src/onediff/infer_compiler/backends/oneflow/import_tools/patch_for_compiler.py
index d696e0153..1dea18a98 100644
--- a/src/onediff/infer_compiler/backends/oneflow/import_tools/patch_for_compiler.py
+++ b/src/onediff/infer_compiler/backends/oneflow/import_tools/patch_for_compiler.py
@@ -16,7 +16,13 @@ def mem_get_info(dev):
 
     @staticmethod
     def _scaled_dot_product_attention_math(
-        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None,
+        query,
+        key,
+        value,
+        attn_mask=None,
+        dropout_p=0.0,
+        is_causal=False,
+        scale=None,
     ):
         scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
 
@@ -51,7 +57,13 @@ def _scaled_dot_product_attention_math(
 
     @staticmethod
     def scaled_dot_product_attention(
-        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None,
+        query,
+        key,
+        value,
+        attn_mask=None,
+        dropout_p=0.0,
+        is_causal=False,
+        scale=None,
     ):
         """Scaled Dot-Product Attention
         Args:
diff --git a/src/onediff/infer_compiler/backends/oneflow/online_quantization_utils.py b/src/onediff/infer_compiler/backends/oneflow/online_quantization_utils.py
index 3cc6effb5..cb8be21a8 100644
--- a/src/onediff/infer_compiler/backends/oneflow/online_quantization_utils.py
+++ b/src/onediff/infer_compiler/backends/oneflow/online_quantization_utils.py
@@ -1,5 +1,6 @@
 import functools
 
+
 def patch_input_adapter(in_args, in_kwargs):
     return in_args, in_kwargs
 