Add kolors compile #1007

Merged
merged 19 commits into from
Jul 23, 2024
121 changes: 121 additions & 0 deletions onediff_diffusers_extensions/examples/kolors/README.md
@@ -0,0 +1,121 @@
# Run kolors with onediff (Beta Release)


## Environment setup

### Set up onediff
https://github.com/siliconflow/onediff?tab=readme-ov-file#installation

### Set up compiler backend
Two compiler backends are supported: oneflow and nexfort.

https://github.com/siliconflow/onediff?tab=readme-ov-file#install-a-compiler-backend


### Set up diffusers

```
# Ensure diffusers includes the kolors pipeline.
pip install git+https://github.com/huggingface/diffusers.git
```
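
To confirm that the installed diffusers build actually ships the Kolors pipeline, a quick check along these lines may help. This is a minimal sketch: the helper name `has_kolors_pipeline` is hypothetical and not part of onediff or diffusers; the only thing taken from this PR is that the example script imports `KolorsPipeline` from `diffusers`.

```python
import importlib


def has_kolors_pipeline() -> bool:
    """Return True if the installed diffusers package exposes KolorsPipeline.

    Hypothetical helper for illustration; it only probes for the class
    name that the example script imports from diffusers.
    """
    try:
        diffusers = importlib.import_module("diffusers")
        return hasattr(diffusers, "KolorsPipeline")
    except Exception:
        # diffusers is missing, or importing it failed for another reason.
        return False


print("KolorsPipeline available:", has_kolors_pipeline())
```

If this prints `False`, reinstalling diffusers from the main branch as shown above is the likely fix.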

### Set up kolors

HF model: https://huggingface.co/Kwai-Kolors/Kolors-diffusers

HF pipeline: https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/pipelines/kolors.md


## Run

### Run 1024*1024 without compile (the original pytorch HF diffusers baseline)
```
python3 onediff_diffusers_extensions/examples/kolors/text_to_image_kolors.py \
--saved-image kolors.png
```

### Run 1024*1024 with compile [oneflow backend]

```
python3 onediff_diffusers_extensions/examples/kolors/text_to_image_kolors.py \
--compiler oneflow \
--saved-image kolors_oneflow_compile.png
```

### Run 1024*1024 with compile [nexfort backend]

```
python3 onediff_diffusers_extensions/examples/kolors/text_to_image_kolors.py \
--compiler nexfort \
--compiler-config '{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last"}' \
--saved-image kolors_nexfort_compile.png
```
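
The `--compiler-config` value is a JSON string that the example script decodes with `json.loads` and forwards as the `options` argument of `compile_pipe`. The sketch below shows only the decoding step; splitting `mode` on `:` reflects how the value is written in the command above and is an assumption about how the flags are meant to be read, not a statement about nexfort internals.

```python
import json

# The same JSON string that is passed to --compiler-config above.
# On the command line it must be wrapped in single quotes so the shell
# leaves the inner double quotes intact.
raw_config = (
    '{"mode": "max-optimize:max-autotune:low-precision", '
    '"memory_format": "channels_last"}'
)

# The example script performs this decoding before calling compile_pipe.
options = json.loads(raw_config)

# Assumption: "mode" is a colon-separated list of optimization flags.
mode_flags = options["mode"].split(":")

print(options["memory_format"])
print(mode_flags)
```

A malformed string raises `json.JSONDecodeError` before any model is loaded, so quoting mistakes surface early.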

## Performance comparison

**Testing on an NVIDIA RTX 4090 GPU, using a resolution of 1024x1024 and 50 steps:**

Data update date: 2024-07-23

| Framework | Iteration Speed (it/s) | E2E Time (seconds) | Max Memory Used (GiB) | Warmup time (seconds) <sup>1</sup> | Warmup with Cache time (seconds) |
|--------------------|------------------------|--------------------|-----------------------|-------------|------------------------|
| PyTorch | 8.11 | 6.55 | 20.623 | 7.09 | - |
| OneDiff (OneFlow) | 15.16 | 3.86 | 20.622 | 39.61 | 7.47 |
| OneDiff (NexFort) | 14.68 | 3.71 | 21.623 | 190.14 | 50.46 |

<sup>1</sup> OneDiff Warmup with Compilation time is tested on AMD EPYC 7543 32-Core Processor.

**Testing on NVIDIA A100-PCIE-40GB:**

Data update date: 2024-07-23

| Framework | Iteration Speed (it/s) | E2E Time (seconds) | Max Memory Used (GiB) | Warmup time (seconds) <sup>2</sup> | Warmup with Cache time (seconds) |
|--------------------|------------------------|--------------------|-----------------------|-------------|------------------------|
| PyTorch | 8.36 | 6.34 | 20.622 | 7.88 | - |
| OneDiff (OneFlow) | 11.54 | 4.69 | 20.627 | 50.02 | 12.82 |
| OneDiff (NexFort) | 10.53 | 5.02 | 21.622 | 269.89 | 73.31 |

<sup>2</sup> OneDiff Warmup with Compilation time is tested on Intel(R) Xeon(R) Gold 6348 CPU @ 2.60GHz.

**Testing on NVIDIA A100-SXM4-80GB:**

Data update date: 2024-07-23

| Framework | Iteration Speed (it/s) | E2E Time (seconds) | Max Memory Used (GiB) | Warmup time (seconds) <sup>3</sup> | Warmup with Cache time (seconds) |
|--------------------|------------------------|--------------------|-----------------------|-------------|------------------------|
| PyTorch | 9.88 | 5.38 | 20.622 | 6.61 | - |
| OneDiff (OneFlow) | 13.70 | 3.96 | 20.627 | 52.93 | 11.79 |
| OneDiff (NexFort) | 13.20 | 4.04 | 21.622 | 150.78 | 58.07 |

<sup>3</sup> OneDiff Warmup with Compilation time is tested on Intel(R) Xeon(R) Platinum 8468.
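
To put the tables in perspective, the sketch below derives the end-to-end speedup over PyTorch from the RTX 4090 figures, together with a rough break-even count: how many images must be generated before the longer compile warmup is paid back. The break-even formula is a simplifying assumption introduced here (extra warmup time divided by time saved per image), not a metric reported by the authors, and the function names are hypothetical.

```python
import math

# (E2E seconds, warmup seconds) copied from the RTX 4090 table above.
rtx4090 = {
    "PyTorch": (6.55, 7.09),
    "OneDiff (OneFlow)": (3.86, 39.61),
    "OneDiff (NexFort)": (3.71, 190.14),
}


def speedup(results, framework, baseline="PyTorch"):
    """Ratio of baseline E2E time to the framework's E2E time."""
    return results[baseline][0] / results[framework][0]


def break_even_images(results, framework, baseline="PyTorch"):
    """Images needed before the extra warmup cost is amortized.

    Simplified model: (warmup - baseline warmup) / (baseline E2E - E2E),
    rounded up to a whole image.
    """
    base_e2e, base_warmup = results[baseline]
    e2e, warmup = results[framework]
    return math.ceil((warmup - base_warmup) / (base_e2e - e2e))


for name in ("OneDiff (OneFlow)", "OneDiff (NexFort)"):
    print(
        f"{name}: {speedup(rtx4090, name):.2f}x faster, "
        f"break-even after {break_even_images(rtx4090, name)} images"
    )
```

The same functions can be applied to the A100 tables by substituting their E2E and warmup columns. Using the "Warmup with Cache" column instead of the cold warmup time lowers the break-even point considerably.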

## Dynamic shape

Run:

```
# oneflow
python3 onediff_diffusers_extensions/examples/kolors/text_to_image_kolors.py \
--compiler oneflow \
--run_multiple_resolutions 1 \
--saved-image kolors_oneflow_compile.png
```

or

```
# nexfort
python3 onediff_diffusers_extensions/examples/kolors/text_to_image_kolors.py \
--height 512 \
--width 768 \
--compiler nexfort \
--compiler-config '{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last", "dynamic": true}' \
--run_multiple_resolutions 1 \
--saved-image kolors_nexfort_compile.png
```
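
With `--run_multiple_resolutions` enabled, the example script re-runs generation over every height/width pair drawn from a fixed list of sizes and asserts that each run finishes in under 20 seconds. The sketch below only enumerates that grid; the `sizes` list is taken from `text_to_image_kolors.py`, while the use of `itertools.product` is an equivalent rewrite of the script's nested loops chosen for brevity.

```python
from itertools import product

# Sizes exercised by the script when --run_multiple_resolutions is set.
sizes = [1024, 768, 576, 512, 256]

# Every (height, width) pair, in the order the nested loops visit them.
resolutions = list(product(sizes, sizes))

print(f"{len(resolutions)} resolutions will be tested")
for height, width in resolutions[:3]:
    print(f"Running at resolution: {height}x{width}")
```

Because the pairs include non-square shapes such as 1024x256, a compiled pipeline that cannot handle dynamic shapes would recompile on most of them and is likely to trip the 20 second limit.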

## Quality

The quality report for accelerating the kolors model with onediff is located at:
https://github.com/siliconflow/odeval/tree/main/models/kolors

204 changes: 204 additions & 0 deletions onediff_diffusers_extensions/examples/kolors/text_to_image_kolors.py
@@ -0,0 +1,204 @@
import argparse
import json
import time

from diffusers import DPMSolverMultistepScheduler, KolorsPipeline
from onediffx import compile_pipe, quantize_pipe, load_pipe, save_pipe
from onediff.infer_compiler import oneflow_compile
import torch


def parse_args():
    parser = argparse.ArgumentParser(
        description="Use onediff to accelerate image generation with Kolors"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="Kwai-Kolors/Kolors-diffusers",
        help="Model path or identifier.",
    )
    parser.add_argument(
        "--compiler",
        type=str,
        default="none",
        help="Compiler backend to use. Options: 'none', 'nexfort', 'oneflow'",
    )
    parser.add_argument(
        "--compiler-config", type=str, help="JSON string for compiler config."
    )
    parser.add_argument(
        "--quantize-config", type=str, help="JSON string for quantization config."
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default='一张瓢虫的照片,微距,变焦,高质量,电影,拿着一个牌子,写着"可图"',
        help="Prompt for the image generation.",
    )
    parser.add_argument(
        "--height", type=int, default=1024, help="Height of the generated image."
    )
    parser.add_argument(
        "--width", type=int, default=1024, help="Width of the generated image."
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=6.5,
        help="The scale factor for the guidance.",
    )
    parser.add_argument(
        "--num-inference-steps", type=int, default=50, help="Number of inference steps."
    )
    parser.add_argument(
        "--saved-image",
        type=str,
        default="./kolors.png",
        help="Path to save the generated image.",
    )
    parser.add_argument(
        "--seed", type=int, default=66, help="Seed for random number generation."
    )
    parser.add_argument(
        "--warmup-iterations",
        type=int,
        default=1,
        help="Number of warm-up iterations before actual inference.",
    )
    parser.add_argument(
        "--run_multiple_resolutions",
        type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
        default=False,
    )
    parser.add_argument("--oneflow_save", action=argparse.BooleanOptionalAction)
    parser.add_argument("--oneflow_load", action=argparse.BooleanOptionalAction)
    return parser.parse_args()


args = parse_args()

device = torch.device("cuda")


class KolorsGenerator:
    def __init__(
        self, model, compiler_config=None, quantize_config=None, compiler="none"
    ):
        self.pipe = KolorsPipeline.from_pretrained(
            model, torch_dtype=torch.float16, variant="fp16"
        ).to(device)
        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipe.scheduler.config, use_karras_sigmas=True
        )

        if compiler == "nexfort":
            if compiler_config:
                print("nexfort backend compile...")
                self.pipe = self.compile_pipe(self.pipe, compiler_config)

            if quantize_config:
                print("nexfort backend quant...")
                self.pipe = self.quantize_pipe(self.pipe, quantize_config)
        elif compiler == "oneflow":
            print("oneflow backend compile...")
            # self.pipe.unet = self.oneflow_compile(self.pipe.unet)
            self.pipe = compile_pipe(self.pipe, ignores=['text_encoder', 'vae'])

    def warmup(self, gen_args, warmup_iterations):
        warmup_args = gen_args.copy()

        warmup_args["generator"] = torch.Generator(device=device).manual_seed(0)

        print("Starting warmup...")
        start_time = time.time()
        if args.oneflow_load:
            load_pipe(self.pipe, dir="cached_pipe")

        for _ in range(warmup_iterations):
            self.pipe(**warmup_args)

        if args.oneflow_save:
            save_pipe(self.pipe, dir="cached_pipe")
        end_time = time.time()
        print("Warmup complete.")
        print(f"Warmup time: {end_time - start_time:.2f} seconds")

    def generate(self, gen_args):
        gen_args["generator"] = torch.Generator(device=device).manual_seed(args.seed)

        # Run the model
        start_time = time.time()
        images = self.pipe(**gen_args).images
        end_time = time.time()

        images[0].save(args.saved_image)

        return images[0], end_time - start_time

    def compile_pipe(self, pipe, compiler_config):
        options = compiler_config
        pipe = compile_pipe(
            pipe, backend="nexfort", options=options, fuse_qkv_projections=True
        )
        return pipe

    def quantize_pipe(self, pipe, quantize_config):
        pipe = quantize_pipe(pipe, ignores=[], **quantize_config)
        return pipe

    def oneflow_compile(self, unet):
        return oneflow_compile(unet)


def main():
    nexfort_compiler_config = (
        json.loads(args.compiler_config) if args.compiler_config else None
    )
    nexfort_quantize_config = (
        json.loads(args.quantize_config) if args.quantize_config else None
    )

    kolors = KolorsGenerator(
        args.model,
        nexfort_compiler_config,
        nexfort_quantize_config,
        compiler=args.compiler,
    )

    gen_args = {
        "prompt": args.prompt,
        "num_inference_steps": args.num_inference_steps,
        "height": args.height,
        "width": args.width,
        "guidance_scale": args.guidance_scale,
    }

    kolors.warmup(gen_args, args.warmup_iterations)

    image, inference_time = kolors.generate(gen_args)
    print(
        f"Generated image saved to {args.saved_image} in {inference_time:.2f} seconds."
    )
    cuda_mem_after_used = torch.cuda.max_memory_allocated() / (1024**3)
    print(f"Max used CUDA memory : {cuda_mem_after_used:.3f}GiB")

    if args.run_multiple_resolutions:
        print("Test run with multiple resolutions...")
        sizes = [1024, 768, 576, 512, 256]
        for h in sizes:
            for w in sizes:
                gen_args["height"] = h
                gen_args["width"] = w
                print(f"Running at resolution: {h}x{w}")
                start_time = time.time()
                kolors.generate(gen_args)
                end_time = time.time()
                print(f"Inference time: {end_time - start_time:.2f} seconds")
                assert (
                    end_time - start_time
                ) < 20, "Resolution switch test took too long"


if __name__ == "__main__":
    main()
@@ -905,12 +905,22 @@ def forward(
 LoRACompatibleLinear = transformed_diffusers.models.lora.LoRACompatibleLinear
 ModelMixin = transformed_diffusers.models.modeling_utils.ModelMixin
 AdaLayerNormSingle = transformed_diffusers.models.normalization.AdaLayerNormSingle
-Transformer2DModelOutput = (
-    transformed_diffusers.models.transformer_2d.Transformer2DModelOutput
-)
-proxy_Transformer2DModel = (
-    transformed_diffusers.models.transformer_2d.Transformer2DModel
-)
+diffusers_0260_v = version.parse("0.26.0")
+diffusers_0280_v = version.parse("0.28.0")
+if diffusers_version >= diffusers_0260_v:
+    Transformer2DModelOutput = (
+        transformed_diffusers.models.transformers.transformer_2d.Transformer2DModelOutput
+    )
+    proxy_Transformer2DModel = (
+        transformed_diffusers.models.transformers.transformer_2d.Transformer2DModel
+    )
+else:
+    Transformer2DModelOutput = (
+        transformed_diffusers.models.transformer_2d.Transformer2DModelOutput
+    )
+    proxy_Transformer2DModel = (
+        transformed_diffusers.models.transformer_2d.Transformer2DModel
+    )

 class Transformer2DModel(proxy_Transformer2DModel):
     def forward(
@@ -1059,6 +1069,12 @@ def forward(
 )

 # 2. Blocks
+if diffusers_version >= diffusers_0280_v:
+    self.caption_projection = None
+    if self.caption_channels is not None:
+        self.caption_projection = PixArtAlphaTextProjection(
+            in_features=self.caption_channels, hidden_size=self.inner_dim
+        )
 if self.caption_projection is not None:
     batch_size = hidden_states.shape[0]
     encoder_hidden_states = self.caption_projection(encoder_hidden_states)