Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix format #1039

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions onediff_comfy_nodes/modules/booster_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from onediff.torch_utils.module_operations import get_sub_module
from onediff.utils.import_utils import is_oneflow_available


@singledispatch
def switch_to_cached_model(new_model, cached_model):
raise NotImplementedError(type(new_model))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def tensor_to_size(source, dest_size):
return source


def get_weight_subidxs(weight,ad_params,sub_idxs):
def get_weight_subidxs(weight, ad_params, sub_idxs):
return weight[ad_params[sub_idxs]]


Expand Down Expand Up @@ -167,7 +167,7 @@ def ipadapter_attention(
if ad_params is not None and ad_params["sub_idxs"] is not None:
if isinstance(weight, torch.Tensor) and weight.dim() != 0:
weight = tensor_to_size(weight, ad_params["full_length"])
weight = get_weight_subidxs(weight,ad_params,"sub_idxs")
weight = get_weight_subidxs(weight, ad_params, "sub_idxs")
# if torch.all(weight == 0):
# return 0
weight = weight.repeat(
Expand All @@ -178,8 +178,8 @@ def ipadapter_attention(

# if image length matches or exceeds full_length get sub_idx images
if cond.shape[0] >= ad_params["full_length"]:
cond = get_weight_subidxs(cond,ad_params,"sub_idxs")
uncond = get_weight_subidxs(uncond,ad_params,"sub_idxs")
cond = get_weight_subidxs(cond, ad_params, "sub_idxs")
uncond = get_weight_subidxs(uncond, ad_params, "sub_idxs")
# otherwise get sub_idxs images
else:
cond = tensor_to_size(cond, ad_params["full_length"])
Expand Down
1 change: 0 additions & 1 deletion onediff_diffusers_extensions/examples/kolors/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,3 @@ python3 onediff_diffusers_extensions/examples/kolors/text_to_image_kolors.py \

The quality report for accelerating the kolors model with onediff is located at:
https://github.com/siliconflow/odeval/tree/main/models/kolors

Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import json
import time

import torch

from diffusers import DPMSolverMultistepScheduler, KolorsPipeline
from onediffx import compile_pipe, quantize_pipe, load_pipe, save_pipe
from onediff.infer_compiler import oneflow_compile
import torch
from onediffx import compile_pipe, load_pipe, quantize_pipe, save_pipe


def parse_args():
Expand Down Expand Up @@ -103,7 +104,7 @@ def __init__(
elif compiler == "oneflow":
print("oneflow backend compile...")
# self.pipe.unet = self.oneflow_compile(self.pipe.unet)
self.pipe = compile_pipe(self.pipe, ignores=['text_encoder', 'vae'])
self.pipe = compile_pipe(self.pipe, ignores=["text_encoder", "vae"])

def warmup(self, gen_args, warmup_iterations):
warmup_args = gen_args.copy()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any, Dict, Optional

import torch
from typing import Optional, Dict, Any
from onediff.infer_compiler.backends.oneflow.transform import transform_mgr

transformed_diffusers = transform_mgr.transform_package("diffusers")
ConfigMixin = transformed_diffusers.configuration_utils.ConfigMixin
register_to_config = transformed_diffusers.configuration_utils.register_to_config
Expand All @@ -20,11 +22,14 @@
LoRACompatibleLinear = transformed_diffusers.models.lora.LoRACompatibleLinear
ModelMixin = transformed_diffusers.models.modeling_utils.ModelMixin
AdaLayerNormSingle = transformed_diffusers.models.normalization.AdaLayerNormSingle
Transformer2DModelOutput = transformed_diffusers.models.transformers.transformer_2d.Transformer2DModelOutput
Transformer2DModelOutput = (
transformed_diffusers.models.transformers.transformer_2d.Transformer2DModelOutput
)
proxy_Transformer2DModel = (
transformed_diffusers.models.transformers.transformer_2d.Transformer2DModel
)


class Transformer2DModel(proxy_Transformer2DModel):
def forward(
self,
Expand Down Expand Up @@ -79,7 +84,9 @@ def forward(
"""
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
logger.warning(
"Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored."
)
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
Expand All @@ -100,7 +107,9 @@ def forward(

# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = (
1 - encoder_attention_mask.to(hidden_states.dtype)
) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

hidden_states_in = hidden_states
Expand All @@ -113,8 +122,16 @@ def forward(
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
elif self.is_input_patches:
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
height, width = (
hidden_states.shape[-2] // self.patch_size,
hidden_states.shape[-1] // self.patch_size,
)
(
hidden_states,
encoder_hidden_states,
timestep,
embedded_timestep,
) = self._operate_on_patched_inputs(
hidden_states, encoder_hidden_states, timestep, added_cond_kwargs
)

Expand All @@ -131,7 +148,9 @@ def custom_forward(*inputs):

return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
Expand Down Expand Up @@ -203,7 +222,9 @@ def custom_forward(*inputs):

return Transformer2DModelOutput(sample=output)

def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim):
def _get_output_for_continuous_inputs(
self, hidden_states, residual, batch_size, height, width, inner_dim
):
# # hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
# hidden_states = (
# hidden_states.permute(0, 2, 1)
Expand All @@ -215,13 +236,17 @@ def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size,
return
if not self.use_linear_projection:
hidden_states = (
hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states.reshape(batch_size, height, width, inner_dim)
.permute(0, 3, 1, 2)
.contiguous()
)
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = (
hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states.reshape(batch_size, height, width, inner_dim)
.permute(0, 3, 1, 2)
.contiguous()
)

output = hidden_states + residual
Expand All @@ -238,7 +263,13 @@ def _get_output_for_vectorized_inputs(self, hidden_states):
return output

def _get_output_for_patched_inputs(
self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None
self,
hidden_states,
timestep,
class_labels,
embedded_timestep,
height=None,
width=None,
):
raise
# import ipdb; ipdb.set_trace()
Expand All @@ -247,10 +278,14 @@ def _get_output_for_patched_inputs(
timestep, class_labels, hidden_dtype=hidden_states.dtype
)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
hidden_states = (
self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
)
hidden_states = self.proj_out_2(hidden_states)
elif self.config.norm_type == "ada_norm_single":
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
shift, scale = (
self.scale_shift_table[None] + embedded_timestep[:, None]
).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
Expand All @@ -261,11 +296,23 @@ def _get_output_for_patched_inputs(
if self.adaln_single is None:
height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
shape=(
-1,
height,
width,
self.patch_size,
self.patch_size,
self.out_channels,
)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
shape=(
-1,
self.out_channels,
height * self.patch_size,
width * self.patch_size,
)
)
return output

Expand All @@ -284,4 +331,4 @@ def _operate_on_continuous_inputs(self, hidden_states):
hidden_states = hidden_states.permute(0, 2, 3, 1).flatten(1, 2)
hidden_states = self.proj_in(hidden_states)

return hidden_states, inner_dim
return hidden_states, inner_dim
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,7 @@ def forward(
)
proxy_Transformer2DModel = (
transformed_diffusers.models.transformer_2d.Transformer2DModel
)
)

class Transformer2DModel(proxy_Transformer2DModel):
def forward(
Expand Down Expand Up @@ -1213,5 +1213,6 @@ def custom_forward(*inputs):
return (output,)

return Transformer2DModelOutput(sample=output)

else:
from .transformer_2d.v_0_28 import Transformer2DModel, Transformer2DModelOutput
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ def forward(
else:
hidden_states = upsampler(hidden_states)

class CrossAttnUpBlock2D(
diffusers_unet_2d_blocks.CrossAttnUpBlock2D
):
class CrossAttnUpBlock2D(diffusers_unet_2d_blocks.CrossAttnUpBlock2D):
def forward(
self,
hidden_states: torch.FloatTensor,
Expand Down Expand Up @@ -185,9 +183,7 @@ def forward(

return hidden_states

class CrossAttnUpBlock2D(
diffusers_unet_2d_blocks.CrossAttnUpBlock2D
):
class CrossAttnUpBlock2D(diffusers_unet_2d_blocks.CrossAttnUpBlock2D):
def forward(
self,
hidden_states: torch.FloatTensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@
UNet2DConditionOutput = (
transformed_diffusers.models.unet_2d_condition.UNet2DConditionOutput
)
proxy_UNet2DConditionModel = transformed_diffusers.models.unet_2d_condition.UNet2DConditionModel
proxy_UNet2DConditionModel = (
transformed_diffusers.models.unet_2d_condition.UNet2DConditionModel
)
else:
UNet2DConditionOutput = (
transformed_diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput
)
proxy_UNet2DConditionModel = transformed_diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
proxy_UNet2DConditionModel = (
transformed_diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
)

try:
USE_PEFT_BACKEND = transformed_diffusers.utils.USE_PEFT_BACKEND
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import oneflow as flow # usort: skip
import functools

from oneflow.framework.args_tree import ArgsTree

from onediff.utils import logger
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@ def mem_get_info(dev):

@staticmethod
def _scaled_dot_product_attention_math(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None,
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
):
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale

Expand Down Expand Up @@ -51,7 +57,13 @@ def _scaled_dot_product_attention_math(

@staticmethod
def scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None,
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
):
"""Scaled Dot-Product Attention
Args:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools


def patch_input_adapter(in_args, in_kwargs):
return in_args, in_kwargs

Expand Down