diff --git a/onediff_diffusers_extensions/examples/kolors/README.md b/onediff_diffusers_extensions/examples/kolors/README.md
new file mode 100644
index 000000000..6bab17951
--- /dev/null
+++ b/onediff_diffusers_extensions/examples/kolors/README.md
@@ -0,0 +1,121 @@
+# Run Kolors with onediff (Beta Release)
+
+## Environment setup
+
+### Set up onediff
+https://github.com/siliconflow/onediff?tab=readme-ov-file#installation
+
+### Set up compiler backend
+Two compiler backends are supported: oneflow and nexfort.
+
+https://github.com/siliconflow/onediff?tab=readme-ov-file#install-a-compiler-backend
+
+### Set up diffusers
+
+```
+# Ensure diffusers includes the Kolors pipeline.
+pip install git+https://github.com/huggingface/diffusers.git
+```
+
+### Set up Kolors
+
+HF model: https://huggingface.co/Kwai-Kolors/Kolors-diffusers
+
+HF pipeline: https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/pipelines/kolors.md
+
+## Run
+
+### Run 1024x1024 without compilation (the original PyTorch HF diffusers baseline)
+```
+python3 onediff_diffusers_extensions/examples/kolors/text_to_image_kolors.py \
+--saved-image kolors.png
+```
+
+### Run 1024x1024 with compilation [oneflow backend]
+
+```
+python3 onediff_diffusers_extensions/examples/kolors/text_to_image_kolors.py \
+--compiler oneflow \
+--saved-image kolors_oneflow_compile.png
+```
+
+### Run 1024x1024 with compilation [nexfort backend]
+
+```
+python3 onediff_diffusers_extensions/examples/kolors/text_to_image_kolors.py \
+--compiler nexfort \
+--compiler-config '{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last"}' \
+--saved-image kolors_nexfort_compile.png
+```
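+
+For reference, the JSON string passed to `--compiler-config` is decoded with `json.loads` and handed straight to onediffx's `compile_pipe` as its `options` argument. A simplified sketch of that flow, reduced from `text_to_image_kolors.py` (the prompt here is a placeholder):
+
+```python
+import json
+
+import torch
+from diffusers import KolorsPipeline
+from onediffx import compile_pipe
+
+# The same JSON string the nexfort run command above passes on the CLI.
+config = json.loads(
+    '{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last"}'
+)
+
+pipe = KolorsPipeline.from_pretrained(
+    "Kwai-Kolors/Kolors-diffusers", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+
+# The decoded dict becomes the nexfort compiler options.
+pipe = compile_pipe(pipe, backend="nexfort", options=config, fuse_qkv_projections=True)
+
+image = pipe(prompt="a photo of a ladybug", num_inference_steps=50).images[0]
+```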
+
+## Performance comparison
+
+**Testing on an NVIDIA RTX 4090 GPU, using a resolution of 1024x1024 and 50 steps:**
+
+Data last updated: 2024-07-23
+
+| Framework | Iteration Speed (it/s) | E2E Time (seconds) | Max Memory Used (GiB) | Warmup time (seconds) <sup>1</sup> | Warmup with Cache time (seconds) |
+|--------------------|------------------------|--------------------|-----------------------|-------------|------------------------|
+| PyTorch | 8.11 | 6.55 | 20.623 | 7.09 | - |
+| OneDiff (OneFlow) | 15.16 (+86.9%) | 3.86 (-41.1%) | 20.622 | 39.61 | 7.47 |
+| OneDiff (NexFort) | 14.68 (+81.0%) | 3.71 (-43.4%) | 21.623 | 190.14 | 50.46 |
+
+<sup>1</sup> OneDiff warmup (with compilation) time was measured on an AMD EPYC 7543 32-Core Processor.
+
+**Testing on NVIDIA A100-PCIE-40GB:**
+
+Data last updated: 2024-07-23
+
+| Framework | Iteration Speed (it/s) | E2E Time (seconds) | Max Memory Used (GiB) | Warmup time (seconds) <sup>2</sup> | Warmup with Cache time (seconds) |
+|--------------------|------------------------|--------------------|-----------------------|-------------|------------------------|
+| PyTorch | 8.36 | 6.34 | 20.622 | 7.88 | - |
+| OneDiff (OneFlow) | 11.54 (+38.0%) | 4.69 (-26.0%) | 20.627 | 50.02 | 12.82 |
+| OneDiff (NexFort) | 10.53 (+26.0%) | 5.02 (-20.8%) | 21.622 | 269.89 | 73.31 |
+
+<sup>2</sup> OneDiff warmup (with compilation) time was measured on an Intel(R) Xeon(R) Gold 6348 CPU @ 2.60GHz.
+
+**Testing on NVIDIA A100-SXM4-80GB:**
+
+Data last updated: 2024-07-23
+
+| Framework | Iteration Speed (it/s) | E2E Time (seconds) | Max Memory Used (GiB) | Warmup time (seconds) <sup>3</sup> | Warmup with Cache time (seconds) |
+|--------------------|------------------------|--------------------|-----------------------|-------------|------------------------|
+| PyTorch | 9.88 | 5.38 | 20.622 | 6.61 | - |
+| OneDiff (OneFlow) | 13.70 (+38.7%) | 3.96 (-26.4%) | 20.627 | 52.93 | 11.79 |
+| OneDiff (NexFort) | 13.20 (+33.6%) | 4.04 (-24.9%) | 21.622 | 150.78 | 58.07 |
+
+<sup>3</sup> OneDiff warmup (with compilation) time was measured on an Intel(R) Xeon(R) Platinum 8468.
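+
+The "Warmup with Cache" column measures a fresh process that reuses a compiled graph saved by an earlier run. In `text_to_image_kolors.py` this is driven by the `--oneflow_save` and `--oneflow_load` flags, which wrap onediffx's `save_pipe`/`load_pipe`. A minimal sketch of the underlying calls (the `cached_pipe` directory matches the script's default; the prompt is a placeholder):
+
+```python
+import torch
+from diffusers import KolorsPipeline
+from onediffx import compile_pipe, load_pipe, save_pipe
+
+pipe = KolorsPipeline.from_pretrained(
+    "Kwai-Kolors/Kolors-diffusers", torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+pipe = compile_pipe(pipe, ignores=["text_encoder", "vae"])  # oneflow backend
+
+# First process: trigger compilation with one warmup call, then persist it.
+pipe(prompt="a photo of a ladybug", num_inference_steps=50)
+save_pipe(pipe, dir="cached_pipe")
+
+# A later process: load the cached graph before the first call, so warmup
+# drops from full compilation time to the "Warmup with Cache" numbers above.
+load_pipe(pipe, dir="cached_pipe")
+```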
+ ) + parser.add_argument( + "--warmup-iterations", + type=int, + default=1, + help="Number of warm-up iterations before actual inference.", + ) + parser.add_argument( + "--run_multiple_resolutions", + type=(lambda x: str(x).lower() in ["true", "1", "yes"]), + default=False, + ) + parser.add_argument("--oneflow_save", action=argparse.BooleanOptionalAction) + parser.add_argument("--oneflow_load", action=argparse.BooleanOptionalAction) + return parser.parse_args() + + +args = parse_args() + +device = torch.device("cuda") + + +class KolorsGenerator: + def __init__( + self, model, compiler_config=None, quantize_config=None, compiler="none" + ): + self.pipe = KolorsPipeline.from_pretrained( + model, torch_dtype=torch.float16, variant="fp16" + ).to(device) + self.pipe.scheduler = DPMSolverMultistepScheduler.from_config( + self.pipe.scheduler.config, use_karras_sigmas=True + ) + + if compiler == "nexfort": + if compiler_config: + print("nexfort backend compile...") + self.pipe = self.compile_pipe(self.pipe, compiler_config) + + if quantize_config: + print("nexfort backend quant...") + self.pipe = self.quantize_pipe(self.pipe, quantize_config) + elif compiler == "oneflow": + print("oneflow backend compile...") + # self.pipe.unet = self.oneflow_compile(self.pipe.unet) + self.pipe = compile_pipe(self.pipe, ignores=['text_encoder', 'vae']) + + def warmup(self, gen_args, warmup_iterations): + warmup_args = gen_args.copy() + + warmup_args["generator"] = torch.Generator(device=device).manual_seed(0) + + print("Starting warmup...") + start_time = time.time() + if args.oneflow_load: + load_pipe(self.pipe, dir="cached_pipe") + + for _ in range(warmup_iterations): + self.pipe(**warmup_args) + + if args.oneflow_save: + save_pipe(self.pipe, dir="cached_pipe") + end_time = time.time() + print("Warmup complete.") + print(f"Warmup time: {end_time - start_time:.2f} seconds") + + def generate(self, gen_args): + gen_args["generator"] = torch.Generator(device=device).manual_seed(args.seed) + + # Run the model + start_time = time.time() + images = self.pipe(**gen_args).images + end_time = time.time() + + images[0].save(args.saved_image) + + return images[0], end_time - start_time + + def compile_pipe(self, pipe, compiler_config): + options = compiler_config + pipe = compile_pipe( + pipe, backend="nexfort", options=options, fuse_qkv_projections=True + ) + return pipe + + def quantize_pipe(self, pipe, quantize_config): + pipe = quantize_pipe(pipe, ignores=[], **quantize_config) + return pipe + + def oneflow_compile(self, unet): + return oneflow_compile(unet) + + +def main(): + nexfort_compiler_config = ( + json.loads(args.compiler_config) if args.compiler_config else None + ) + nexfort_quantize_config = ( + json.loads(args.quantize_config) if args.quantize_config else None + ) + + kolors = KolorsGenerator( + args.model, + nexfort_compiler_config, + nexfort_quantize_config, + compiler=args.compiler, + ) + + gen_args = { + "prompt": args.prompt, + "num_inference_steps": args.num_inference_steps, + "height": args.height, + "width": args.width, + "guidance_scale": args.guidance_scale, + } + + kolors.warmup(gen_args, args.warmup_iterations) + + image, inference_time = kolors.generate(gen_args) + print( + f"Generated image saved to {args.saved_image} in {inference_time:.2f} seconds." 
+ ) + cuda_mem_after_used = torch.cuda.max_memory_allocated() / (1024**3) + print(f"Max used CUDA memory : {cuda_mem_after_used:.3f}GiB") + + if args.run_multiple_resolutions: + print("Test run with multiple resolutions...") + sizes = [1024, 768, 576, 512, 256] + for h in sizes: + for w in sizes: + gen_args["height"] = h + gen_args["width"] = w + print(f"Running at resolution: {h}x{w}") + start_time = time.time() + kolors.generate(gen_args) + end_time = time.time() + print(f"Inference time: {end_time - start_time:.2f} seconds") + assert ( + end_time - start_time + ) < 20, "Resolution switch test took too long" + + +if __name__ == "__main__": + main() diff --git a/src/infer_compiler_registry/register_diffusers/transformer_2d_oflow.py b/src/infer_compiler_registry/register_diffusers/transformer_2d_oflow.py index 07512fc12..cc91c30b3 100644 --- a/src/infer_compiler_registry/register_diffusers/transformer_2d_oflow.py +++ b/src/infer_compiler_registry/register_diffusers/transformer_2d_oflow.py @@ -905,12 +905,22 @@ def forward( LoRACompatibleLinear = transformed_diffusers.models.lora.LoRACompatibleLinear ModelMixin = transformed_diffusers.models.modeling_utils.ModelMixin AdaLayerNormSingle = transformed_diffusers.models.normalization.AdaLayerNormSingle - Transformer2DModelOutput = ( - transformed_diffusers.models.transformer_2d.Transformer2DModelOutput - ) - proxy_Transformer2DModel = ( - transformed_diffusers.models.transformer_2d.Transformer2DModel - ) + diffusers_0260_v = version.parse("0.26.0") + diffusers_0280_v = version.parse("0.28.0") + if diffusers_version >= diffusers_0260_v: + Transformer2DModelOutput = ( + transformed_diffusers.models.transformers.transformer_2d.Transformer2DModelOutput + ) + proxy_Transformer2DModel = ( + transformed_diffusers.models.transformers.transformer_2d.Transformer2DModel + ) + else: + Transformer2DModelOutput = ( + transformed_diffusers.models.transformer_2d.Transformer2DModelOutput + ) + proxy_Transformer2DModel = ( + transformed_diffusers.models.transformer_2d.Transformer2DModel + ) class Transformer2DModel(proxy_Transformer2DModel): def forward( @@ -1059,6 +1069,12 @@ def forward( ) # 2. 
diff --git a/src/infer_compiler_registry/register_diffusers/transformer_2d_oflow.py b/src/infer_compiler_registry/register_diffusers/transformer_2d_oflow.py
index 07512fc12..cc91c30b3 100644
--- a/src/infer_compiler_registry/register_diffusers/transformer_2d_oflow.py
+++ b/src/infer_compiler_registry/register_diffusers/transformer_2d_oflow.py
@@ -905,12 +905,22 @@ def forward(
     LoRACompatibleLinear = transformed_diffusers.models.lora.LoRACompatibleLinear
     ModelMixin = transformed_diffusers.models.modeling_utils.ModelMixin
     AdaLayerNormSingle = transformed_diffusers.models.normalization.AdaLayerNormSingle
-    Transformer2DModelOutput = (
-        transformed_diffusers.models.transformer_2d.Transformer2DModelOutput
-    )
-    proxy_Transformer2DModel = (
-        transformed_diffusers.models.transformer_2d.Transformer2DModel
-    )
+    diffusers_0260_v = version.parse("0.26.0")
+    diffusers_0280_v = version.parse("0.28.0")
+    if diffusers_version >= diffusers_0260_v:
+        Transformer2DModelOutput = (
+            transformed_diffusers.models.transformers.transformer_2d.Transformer2DModelOutput
+        )
+        proxy_Transformer2DModel = (
+            transformed_diffusers.models.transformers.transformer_2d.Transformer2DModel
+        )
+    else:
+        Transformer2DModelOutput = (
+            transformed_diffusers.models.transformer_2d.Transformer2DModelOutput
+        )
+        proxy_Transformer2DModel = (
+            transformed_diffusers.models.transformer_2d.Transformer2DModel
+        )
 
     class Transformer2DModel(proxy_Transformer2DModel):
         def forward(
@@ -1059,6 +1069,12 @@ def forward(
             )
 
             # 2. Blocks
+            if diffusers_version >= diffusers_0280_v:
+                self.caption_projection = None
+                if self.caption_channels is not None:
+                    self.caption_projection = PixArtAlphaTextProjection(
+                        in_features=self.caption_channels, hidden_size=self.inner_dim
+                    )
             if self.caption_projection is not None:
                 batch_size = hidden_states.shape[0]
                 encoder_hidden_states = self.caption_projection(encoder_hidden_states)
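
All three registry files in this change gate their proxy-class lookups on the diffusers version, since diffusers 0.26.0 moved the model modules under `models.transformers/` and `models.unets/`. A hypothetical helper distilling the shared idiom; the actual patch spells out each branch explicitly, as in the hunks above and in the two files below:

```python
import importlib.metadata

from packaging import version

diffusers_version = version.parse(importlib.metadata.version("diffusers"))
diffusers_0260_v = version.parse("0.26.0")


def pick(root, new_path, old_path):
    """Resolve an attribute path against the diffusers >= 0.26 layout, else the old one."""
    path = new_path if diffusers_version >= diffusers_0260_v else old_path
    obj = root
    for part in path.split("."):
        obj = getattr(obj, part)
    return obj


# For example, with the transformed diffusers package as root:
# UNet2DConditionOutput = pick(
#     transformed_diffusers,
#     "models.unets.unet_2d_condition.UNet2DConditionOutput",
#     "models.unet_2d_condition.UNet2DConditionOutput",
# )
```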
diff --git a/src/infer_compiler_registry/register_diffusers/unet_2d_blocks_oflow.py b/src/infer_compiler_registry/register_diffusers/unet_2d_blocks_oflow.py
index ffc8368f2..bae112a59 100644
--- a/src/infer_compiler_registry/register_diffusers/unet_2d_blocks_oflow.py
+++ b/src/infer_compiler_registry/register_diffusers/unet_2d_blocks_oflow.py
@@ -6,6 +6,7 @@
 from packaging import version
 
 diffusers_0210_v = version.parse("0.21.0")
+diffusers_0260_v = version.parse("0.26.0")
 diffusers_version = version.parse(importlib.metadata.version("diffusers"))
 
 transformed_diffusers = transform_mgr.transform_package("diffusers")
@@ -150,7 +151,16 @@ def custom_forward(*inputs):
 
 else:
 
-    class AttnUpBlock2D(transformed_diffusers.models.unet_2d_blocks.AttnUpBlock2D):
+    if diffusers_version >= diffusers_0260_v:
+        AttnUpBlock2DBase = transformed_diffusers.models.unets.unet_2d_blocks.AttnUpBlock2D
+        CrossAttnUpBlock2DBase = transformed_diffusers.models.unets.unet_2d_blocks.CrossAttnUpBlock2D
+        UpBlock2DBase = transformed_diffusers.models.unets.unet_2d_blocks.UpBlock2D
+    else:
+        AttnUpBlock2DBase = transformed_diffusers.models.unet_2d_blocks.AttnUpBlock2D
+        CrossAttnUpBlock2DBase = transformed_diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D
+        UpBlock2DBase = transformed_diffusers.models.unet_2d_blocks.UpBlock2D
+
+    class AttnUpBlock2D(AttnUpBlock2DBase):
         def forward(
             self,
             hidden_states: torch.FloatTensor,
@@ -179,9 +189,7 @@ def forward(
 
         return hidden_states
 
-    class CrossAttnUpBlock2D(
-        transformed_diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D
-    ):
+    class CrossAttnUpBlock2D(CrossAttnUpBlock2DBase):
         def forward(
             self,
             hidden_states: torch.FloatTensor,
@@ -277,7 +285,7 @@ def custom_forward(*inputs):
 
         return hidden_states
 
-    class UpBlock2D(transformed_diffusers.models.unet_2d_blocks.UpBlock2D):
+    class UpBlock2D(UpBlock2DBase):
         def forward(
             self,
             hidden_states: torch.FloatTensor,
diff --git a/src/infer_compiler_registry/register_diffusers/unet_2d_condition_oflow.py b/src/infer_compiler_registry/register_diffusers/unet_2d_condition_oflow.py
index 8154a3147..63afd961b 100644
--- a/src/infer_compiler_registry/register_diffusers/unet_2d_condition_oflow.py
+++ b/src/infer_compiler_registry/register_diffusers/unet_2d_condition_oflow.py
@@ -9,9 +9,15 @@
 diffusers_version = version.parse(importlib.metadata.version("diffusers"))
 
 transformed_diffusers = transform_mgr.transform_package("diffusers")
-UNet2DConditionOutput = (
-    transformed_diffusers.models.unet_2d_condition.UNet2DConditionOutput
-)
+diffusers_0260_v = version.parse("0.26.0")
+if diffusers_version >= diffusers_0260_v:
+    UNet2DConditionOutput = (
+        transformed_diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput
+    )
+else:
+    UNet2DConditionOutput = (
+        transformed_diffusers.models.unet_2d_condition.UNet2DConditionOutput
+    )
 
 try:
     USE_PEFT_BACKEND = transformed_diffusers.utils.USE_PEFT_BACKEND
@@ -21,9 +27,12 @@
     USE_PEFT_BACKEND = False
 
 
-class UNet2DConditionModel(
-    transformed_diffusers.models.unet_2d_condition.UNet2DConditionModel
-):
+if diffusers_version >= diffusers_0260_v:
+    UNet2DConditionModelBase = transformed_diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
+else:
+    UNet2DConditionModelBase = transformed_diffusers.models.unet_2d_condition.UNet2DConditionModel
+
+class UNet2DConditionModel(UNet2DConditionModelBase):
     def forward(
         self,
         sample: torch.FloatTensor,