support fp8 t5 encoder in examples #366

Merged 3 commits on Nov 28, 2024
9 changes: 9 additions & 0 deletions examples/flux_example.py
@@ -2,6 +2,7 @@
import time
import torch
import torch.distributed
from transformers import T5EncoderModel
from xfuser import xFuserFluxPipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
@@ -20,11 +21,19 @@ def main():
engine_config, input_config = engine_args.create_config()
engine_config.runtime_config.dtype = torch.bfloat16
local_rank = get_world_group().local_rank
text_encoder_2 = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder_2", torch_dtype=torch.bfloat16)

if args.use_fp8_t5_encoder:
from optimum.quanto import freeze, qfloat8, quantize
logging.info(f"rank {local_rank} quantizing text encoder 2")
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

pipe = xFuserFluxPipeline.from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
engine_config=engine_config,
torch_dtype=torch.bfloat16,
text_encoder_2=text_encoder_2,
)

if args.enable_sequential_cpu_offload:
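Every example in this PR applies the same three-step recipe. A standalone sketch of the pattern follows; the checkpoint name is illustrative (the examples take it from engine_config.model_config.model), and quantize/freeze/qfloat8 are the optimum-quanto calls used in the diff above:

import torch
from transformers import T5EncoderModel
from optimum.quanto import freeze, qfloat8, quantize

# 1) Load the T5 encoder in half precision.
text_encoder_2 = T5EncoderModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint
    subfolder="text_encoder_2",
    torch_dtype=torch.bfloat16,
)

# 2) Mark the encoder's weights for FP8 quantization.
quantize(text_encoder_2, weights=qfloat8)

# 3) Freeze to materialize the FP8 weights and discard the bf16 originals.
freeze(text_encoder_2)

# The quantized encoder is then handed to the pipeline, e.g.
# xFuserFluxPipeline.from_pretrained(..., text_encoder_2=text_encoder_2).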
9 changes: 9 additions & 0 deletions examples/hunyuandit_example.py
@@ -2,6 +2,7 @@
import os
import torch
import torch.distributed
from transformers import T5EncoderModel
from xfuser import xFuserHunyuanDiTPipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
@@ -19,10 +20,18 @@ def main():
engine_args = xFuserArgs.from_cli_args(args)
engine_config, input_config = engine_args.create_config()
local_rank = get_world_group().local_rank
text_encoder_2 = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder_2", torch_dtype=torch.bfloat16)
if args.use_fp8_t5_encoder:
from optimum.quanto import freeze, qfloat8, quantize
print(f"rank {local_rank} quantizing text encoder 2")
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

pipe = xFuserHunyuanDiTPipeline.from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
engine_config=engine_config,
torch_dtype=torch.float16,
text_encoder_2=text_encoder_2,
).to(f"cuda:{local_rank}")

parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
9 changes: 9 additions & 0 deletions examples/pixartalpha_example.py
@@ -2,6 +2,7 @@
import os
import torch
import torch.distributed
from transformers import T5EncoderModel
from xfuser import xFuserPixArtAlphaPipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
@@ -19,10 +20,18 @@ def main():
engine_args = xFuserArgs.from_cli_args(args)
engine_config, input_config = engine_args.create_config()
local_rank = get_world_group().local_rank
text_encoder = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder", torch_dtype=torch.float16)
if args.use_fp8_t5_encoder:
from optimum.quanto import freeze, qfloat8, quantize
Collaborator: add optimum in setup.py

print(f"rank {local_rank} quantizing text encoder")
quantize(text_encoder, weights=qfloat8)
freeze(text_encoder)

pipe = xFuserPixArtAlphaPipeline.from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
engine_config=engine_config,
torch_dtype=torch.float16,
text_encoder=text_encoder,
).to(f"cuda:{local_rank}")
model_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
pipe.prepare_run(input_config)
9 changes: 9 additions & 0 deletions examples/pixartsigma_example.py
@@ -2,6 +2,7 @@
import os
import torch
import torch.distributed
from transformers import T5EncoderModel
from xfuser import xFuserPixArtSigmaPipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
@@ -19,10 +20,18 @@ def main():
engine_args = xFuserArgs.from_cli_args(args)
engine_config, input_config = engine_args.create_config()
local_rank = get_world_group().local_rank
text_encoder = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder", torch_dtype=torch.float16)
if args.use_fp8_t5_encoder:
from optimum.quanto import freeze, qfloat8, quantize
print(f"rank {local_rank} quantizing text encoder")
quantize(text_encoder, weights=qfloat8)
freeze(text_encoder)

pipe = xFuserPixArtSigmaPipeline.from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
engine_config=engine_config,
torch_dtype=torch.float16,
text_encoder=text_encoder,
).to(f"cuda:{local_rank}")
pipe.prepare_run(input_config)

6 changes: 5 additions & 1 deletion examples/run.sh
@@ -46,6 +46,9 @@ PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 2
# COMPILE_FLAG="--use_torch_compile"


# Use this flag to quantize the T5 text encoder to FP8, which can reduce memory usage with little to no impact on output quality.
# QUANTIZE_FLAG="--use_fp8_t5_encoder"

# export CUDA_VISIBLE_DEVICES=4,5,6,7

torchrun --nproc_per_node=$N_GPUS ./examples/$SCRIPT \
@@ -59,4 +62,5 @@ $OUTPUT_ARGS \
--prompt "brown dog laying on the ground with a metal bowl in front of him." \
$CFG_ARGS \
$PARALLLEL_VAE \
$COMPILE_FLAG
$COMPILE_FLAG \
$QUANTIZE_FLAG
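With QUANTIZE_FLAG uncommented, the script passes --use_fp8_t5_encoder through to the chosen example. A sketch of the equivalent direct invocation; the GPU count and model path are illustrative, and the other flags come from the script's existing PARALLEL_ARGS and prompt:

N_GPUS=2
QUANTIZE_FLAG="--use_fp8_t5_encoder"
torchrun --nproc_per_node=$N_GPUS ./examples/sd3_example.py \
    --model /path/to/stable-diffusion-3 \
    --pipefusion_parallel_degree 2 \
    --prompt "brown dog laying on the ground with a metal bowl in front of him." \
    $QUANTIZE_FLAG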
9 changes: 9 additions & 0 deletions examples/sd3_example.py
@@ -2,6 +2,7 @@
import os
import torch
import torch.distributed
from transformers import T5EncoderModel
from xfuser import xFuserStableDiffusion3Pipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
@@ -19,10 +20,18 @@ def main():
engine_args = xFuserArgs.from_cli_args(args)
engine_config, input_config = engine_args.create_config()
local_rank = get_world_group().local_rank
text_encoder_3 = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder_3", torch_dtype=torch.float16)
if args.use_fp8_t5_encoder:
from optimum.quanto import freeze, qfloat8, quantize
print(f"rank {local_rank} quantizing text encoder 2")
quantize(text_encoder_3, weights=qfloat8)
freeze(text_encoder_3)

pipe = xFuserStableDiffusion3Pipeline.from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
engine_config=engine_config,
torch_dtype=torch.float16,
text_encoder_3=text_encoder_3,
).to(f"cuda:{local_rank}")

parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
1 change: 1 addition & 0 deletions setup.py
@@ -39,6 +39,7 @@ def get_cuda_version():
"opencv-python",
"imageio",
"imageio-ffmpeg",
"optimum-quanto"
],
extras_require={
"flash_attn": [
7 changes: 7 additions & 0 deletions xfuser/config/args.py
@@ -105,6 +105,7 @@ class xFuserArgs:
window_size: int = 64
coco_path: Optional[str] = None
use_cache: bool = False
use_fp8_t5_encoder: bool = False

@staticmethod
def add_cli_args(parser: FlexibleArgumentParser):
@@ -265,6 +266,11 @@ def add_cli_args(parser: FlexibleArgumentParser):
action="store_true",
help="Making VAE decode a tile at a time to save GPU memory.",
)
runtime_group.add_argument(
"--use_fp8_t5_encoder",
action="store_true",
help="Quantize the T5 text encoder.",
)

# DiTFastAttn arguments
fast_attn_group = parser.add_argument_group("DiTFastAttn Options")
@@ -335,6 +341,7 @@ def create_config(
use_torch_compile=self.use_torch_compile,
use_onediff=self.use_onediff,
# use_profiler=self.use_profiler,
use_fp8_t5_encoder=self.use_fp8_t5_encoder,
)

parallel_config = ParallelConfig(
1 change: 1 addition & 0 deletions xfuser/config/config.py
@@ -59,6 +59,7 @@ class RuntimeConfig:
use_profiler: bool = False
use_torch_compile: bool = False
use_onediff: bool = False
use_fp8_t5_encoder: bool = False

def __post_init__(self):
check_packages()
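Taken together, the args.py and config.py changes plumb the new flag from the CLI into RuntimeConfig. A minimal sketch of that flow, reusing the parsing calls from the example scripts; the model path is illustrative, and create_config is assumed to run under torchrun since it expects the distributed environment the examples set up:

from xfuser import xFuserArgs
from xfuser.config import FlexibleArgumentParser

parser = FlexibleArgumentParser(description="xFuser Arguments")
# Parse a synthetic argv with the new flag set.
args = xFuserArgs.add_cli_args(parser).parse_args(
    ["--model", "/path/to/model", "--use_fp8_t5_encoder"]
)
engine_args = xFuserArgs.from_cli_args(args)
engine_config, input_config = engine_args.create_config()

# xFuserArgs.use_fp8_t5_encoder was forwarded into RuntimeConfig.
assert engine_config.runtime_config.use_fp8_t5_encoder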