[ray] launch multiple GPUs with Ray #396

Merged Dec 20, 2024 (10 commits)
64 changes: 64 additions & 0 deletions examples/ray/ray_flux_example.py
@@ -0,0 +1,64 @@
import time
import os
import torch
import torch.distributed
from transformers import T5EncoderModel
from xfuser import xFuserArgs
from xfuser.ray.pipeline.pipeline_utils import RayDiffusionPipeline
from xfuser.config import FlexibleArgumentParser
from xfuser.model_executor.pipelines import xFuserPixArtAlphaPipeline, xFuserPixArtSigmaPipeline, xFuserStableDiffusion3Pipeline, xFuserHunyuanDiTPipeline, xFuserFluxPipeline

def main():
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)
    engine_config, input_config = engine_args.create_config()
    engine_config.runtime_config.dtype = torch.bfloat16
    model_name = engine_config.model_config.model.split("/")[-1]
    PipelineClass = xFuserFluxPipeline
    text_encoder_2 = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder_2", torch_dtype=torch.bfloat16)
    if args.use_fp8_t5_encoder:
        from optimum.quanto import freeze, qfloat8, quantize
        quantize(text_encoder_2, weights=qfloat8)
        freeze(text_encoder_2)

    pipe = RayDiffusionPipeline.from_pretrained(
        PipelineClass=PipelineClass,
        pretrained_model_name_or_path=engine_config.model_config.model,
        engine_config=engine_config,
        torch_dtype=torch.bfloat16,
        text_encoder_2=text_encoder_2,
    )
    pipe.prepare_run(input_config)

    start_time = time.time()
    output = pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        output_type=input_config.output_type,
        max_sequence_length=256,
        guidance_scale=0.0,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    )
    end_time = time.time()
    elapsed_time = end_time - start_time

    print(f"elapsed time: {elapsed_time}")
    if not os.path.exists("results"):
        os.mkdir("results")
    # output is a list of results from each worker; we take the last one
    for i, image in enumerate(output[-1].images):
        image.save(
            f"./results/{model_name}_result_{i}.png"
        )
        print(
            f"image {i} saved to ./results/{model_name}_result_{i}.png"
        )


if __name__ == "__main__":
    main()
68 changes: 68 additions & 0 deletions examples/ray/ray_run.sh
@@ -0,0 +1,68 @@
set -x
# If you are using a Ray cluster across multiple machines, start the cluster manually first:
#   ray start --head --port=6379            # on the head (master) node
#   ray start --address='192.168.1.1:6379'  # on each worker node
# For a single node this is not necessary.

export PYTHONPATH=$PWD:$PYTHONPATH

# Select the model type
export MODEL_TYPE="Flux"
# Configuration for different model types
# script, model_id, inference_step
declare -A MODEL_CONFIGS=(
    ["Sd3"]="ray_sd3_example.py /cfs/dit/stable-diffusion-3-medium-diffusers 20"
    ["Flux"]="ray_flux_example.py /cfs/dit/FLUX.1-dev 28"
)

if [[ -v MODEL_CONFIGS[$MODEL_TYPE] ]]; then
    IFS=' ' read -r SCRIPT MODEL_ID INFERENCE_STEP <<< "${MODEL_CONFIGS[$MODEL_TYPE]}"
    export SCRIPT MODEL_ID INFERENCE_STEP
else
    echo "Invalid MODEL_TYPE: $MODEL_TYPE"
    exit 1
fi

mkdir -p ./results

# task args
TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"


N_GPUS=2
PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"

# CFG_ARGS="--use_cfg_parallel"

# By default, num_pipeline_patch = pipefusion_degree, and you can tune this parameter to achieve optimal performance.
# PIPEFUSION_ARGS="--num_pipeline_patch 8 "

# For high-resolution images, we use the latent output type to avoid running the VAE module; used for measuring speed.
# OUTPUT_ARGS="--output_type latent"

# PARALLEL_VAE="--use_parallel_vae"

# Another compile option is `--use_onediff` which will use onediff's compiler.
# COMPILE_FLAG="--use_torch_compile"


# Use this flag to quantize the T5 text encoder, which can reduce memory usage with no effect on output quality.
# QUANTIZE_FLAG="--use_fp8_t5_encoder"

export CUDA_VISIBLE_DEVICES=0,1

python ./examples/ray/$SCRIPT \
    --model $MODEL_ID \
    $PARALLEL_ARGS \
    $TASK_ARGS \
    $PIPEFUSION_ARGS \
    $OUTPUT_ARGS \
    --num_inference_steps $INFERENCE_STEP \
    --warmup_steps 1 \
    --prompt "brown dog laying on the ground with a metal bowl in front of him." \
    --use_ray \
    --ray_world_size $N_GPUS \
    $CFG_ARGS \
    $PARALLEL_VAE \
    $COMPILE_FLAG \
    $QUANTIZE_FLAG
77 changes: 77 additions & 0 deletions examples/ray/ray_sd3_example.py
@@ -0,0 +1,77 @@
import time
import os
import torch
import torch.distributed
from transformers import T5EncoderModel
from xfuser import xFuserArgs
from xfuser.ray.pipeline.pipeline_utils import RayDiffusionPipeline
from xfuser.config import FlexibleArgumentParser
from xfuser.model_executor.pipelines import xFuserStableDiffusion3Pipeline


def main():
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)
    engine_config, input_config = engine_args.create_config()
    model_name = engine_config.model_config.model.split("/")[-1]
    PipelineClass = xFuserStableDiffusion3Pipeline
    text_encoder_3 = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder_3", torch_dtype=torch.float16)
    if args.use_fp8_t5_encoder:
        from optimum.quanto import freeze, qfloat8, quantize
        print("quantizing text encoder 3 to fp8")
        quantize(text_encoder_3, weights=qfloat8)
        freeze(text_encoder_3)

    pipe = RayDiffusionPipeline.from_pretrained(
        PipelineClass=PipelineClass,
        pretrained_model_name_or_path=engine_config.model_config.model,
        engine_config=engine_config,
        torch_dtype=torch.float16,
        text_encoder_3=text_encoder_3,
    )
    pipe.prepare_run(input_config)

    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()
    output = pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        output_type=input_config.output_type,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    )
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"elapsed time: {elapsed_time}")
    if not os.path.exists("results"):
        os.mkdir("results")
    # output is a list of results from each worker; we take the last one
    for i, image in enumerate(output[-1].images):
        image.save(
            f"./results/{model_name}_result_{i}.png"
        )
        print(
            f"image {i} saved to ./results/{model_name}_result_{i}.png"
        )


if __name__ == "__main__":
    main()
3 changes: 2 additions & 1 deletion setup.py
@@ -39,7 +39,8 @@ def get_cuda_version():
"imageio",
"imageio-ffmpeg",
"optimum-quanto",
"flash_attn>=2.6.3" # flash_attn>=2.7.0 with torch>=2.4.0 wraps ops with torch.ops
"flash_attn>=2.6.3",
"ray"
],
extras_require={
"diffusers": [
25 changes: 24 additions & 1 deletion xfuser/config/args.py
@@ -79,6 +79,9 @@ class xFuserArgs:
    # tensor parallel
    tensor_parallel_degree: int = 1
    split_scheme: Optional[str] = "row"
    # ray arguments
    use_ray: bool = False
    ray_world_size: int = 1
    # pipefusion parallel
    pipefusion_parallel_degree: int = 1
    num_pipeline_patch: Optional[int] = None
@@ -151,6 +154,17 @@ def add_cli_args(parser: FlexibleArgumentParser):

        # Parallel arguments
        parallel_group = parser.add_argument_group("Parallel Processing Options")
        runtime_group.add_argument(
            "--use_ray",
            action="store_true",
            help="Enable Ray to run inference on multiple GPUs",
        )
        parallel_group.add_argument(
            "--ray_world_size",
            type=int,
            default=1,
            help="The number of Ray workers (world_size for Ray)",
        )
        parallel_group.add_argument(
            "--use_cfg_parallel",
            action="store_true",
@@ -322,11 +336,15 @@ def from_cli_args(cls, args: argparse.Namespace):
    def create_config(
        self,
    ) -> Tuple[EngineConfig, InputConfig]:
-        if not torch.distributed.is_initialized():
+        if not self.use_ray and not torch.distributed.is_initialized():
            logger.warning(
                "Distributed environment is not initialized. Initializing..."
            )
            init_distributed_environment()
        if self.use_ray:
            self.world_size = self.ray_world_size
        else:
            self.world_size = torch.distributed.get_world_size()

        model_config = ModelConfig(
            model=self.model,
@@ -348,20 +366,25 @@ def create_config(
            dp_config=DataParallelConfig(
                dp_degree=self.data_parallel_degree,
                use_cfg_parallel=self.use_cfg_parallel,
                world_size=self.world_size,
            ),
            sp_config=SequenceParallelConfig(
                ulysses_degree=self.ulysses_degree,
                ring_degree=self.ring_degree,
                world_size=self.world_size,
            ),
            tp_config=TensorParallelConfig(
                tp_degree=self.tensor_parallel_degree,
                split_scheme=self.split_scheme,
                world_size=self.world_size,
            ),
            pp_config=PipeFusionParallelConfig(
                pp_degree=self.pipefusion_parallel_degree,
                num_pipeline_patch=self.num_pipeline_patch,
                attn_layer_num_for_pp=self.attn_layer_num_for_pp,
                world_size=self.world_size,
            ),
            world_size=self.world_size,
        )

        fast_attn_config = FastAttnConfig(
20 changes: 13 additions & 7 deletions xfuser/config/config.py
@@ -86,6 +86,7 @@ def __post_init__(self):
class DataParallelConfig:
    dp_degree: int = 1
    use_cfg_parallel: bool = False
    world_size: int = 1

    def __post_init__(self):
        assert self.dp_degree >= 1, "dp_degree must be greater than or equal to 1"
@@ -95,19 +96,20 @@ def __post_init__(self):
            self.cfg_degree = 2
        else:
            self.cfg_degree = 1
-        assert self.dp_degree * self.cfg_degree <= dist.get_world_size(), (
+        assert self.dp_degree * self.cfg_degree <= self.world_size, (
            "dp_degree * cfg_degree must be less than or equal to "
            "world_size because of classifier free guidance"
        )
        assert (
-            dist.get_world_size() % (self.dp_degree * self.cfg_degree) == 0
+            self.world_size % (self.dp_degree * self.cfg_degree) == 0
        ), "world_size must be divisible by dp_degree * cfg_degree"


@dataclass
class SequenceParallelConfig:
    ulysses_degree: Optional[int] = None
    ring_degree: Optional[int] = None
    world_size: int = 1

    def __post_init__(self):
        if self.ulysses_degree is None:
@@ -138,11 +140,12 @@ def __post_init__(self):
class TensorParallelConfig:
    tp_degree: int = 1
    split_scheme: Optional[str] = "row"
    world_size: int = 1

    def __post_init__(self):
        assert self.tp_degree >= 1, "tp_degree must be greater than or equal to 1"
        assert (
-            self.tp_degree <= dist.get_world_size()
+            self.tp_degree <= self.world_size
        ), "tp_degree must be less than or equal to world_size"


@@ -151,13 +154,14 @@ class PipeFusionParallelConfig:
    pp_degree: int = 1
    num_pipeline_patch: Optional[int] = None
    attn_layer_num_for_pp: Optional[List[int]] = (None,)
    world_size: int = 1

    def __post_init__(self):
        assert (
            self.pp_degree is not None and self.pp_degree >= 1
        ), "pipefusion_degree must be set and at least 1 to use pipefusion"
        assert (
-            self.pp_degree <= dist.get_world_size()
+            self.pp_degree <= self.world_size
        ), "pipefusion_degree must be less than or equal to world_size"
        if self.num_pipeline_patch is None:
            self.num_pipeline_patch = self.pp_degree
@@ -188,6 +192,8 @@ class ParallelConfig:
    sp_config: SequenceParallelConfig
    pp_config: PipeFusionParallelConfig
    tp_config: TensorParallelConfig
    world_size: int = 1  # FIXME: remove this
    worker_cls: str = "xfuser.ray.worker.worker.Worker"
Collaborator:
do we need distributed_executor_backend and worker_cls?

Contributor Author:
We don't need distributed_executor_backend, but we do need worker_cls so that Ray can initialize the worker from its class name:

def init_worker(self, *args, **kwargs):
      worker_class = resolve_obj_by_qualname(
          self.worker_cls)
      self.worker = worker_class(*args, **kwargs)
      assert self.worker is not None
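For reference, resolve_obj_by_qualname is not shown in this diff; the sketch below is only a minimal illustration of how such a qualname resolver is commonly implemented with importlib, and its exact name and behavior in xfuser are assumptions rather than code from this PR:

import importlib

def resolve_obj_by_qualname(qualname: str):
    # Split "xfuser.ray.worker.worker.Worker" into a module path and an attribute
    # name, import the module, and return the attribute (here, the Worker class).
    module_name, _, obj_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)

# Hypothetical usage, mirroring init_worker above:
# worker_class = resolve_obj_by_qualname("xfuser.ray.worker.worker.Worker")
# worker = worker_class(*args, **kwargs)

Carrying the worker class as a plain string in ParallelConfig keeps the config serializable and lets each Ray worker process import and construct its own worker lazily.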


    def __post_init__(self):
        assert self.tp_config is not None, "tp_config must be set"
@@ -201,10 +207,10 @@ def __post_init__(self):
            * self.tp_config.tp_degree
            * self.pp_config.pp_degree
        )
-        world_size = dist.get_world_size()
+        world_size = self.world_size
        assert parallel_world_size == world_size, (
            f"parallel_world_size {parallel_world_size} "
-            f"must be equal to world_size {dist.get_world_size()}"
+            f"must be equal to world_size {self.world_size}"
        )
        assert (
            world_size % (self.dp_config.dp_degree * self.dp_config.cfg_degree) == 0
@@ -236,7 +242,7 @@ class EngineConfig:
    fast_attn_config: FastAttnConfig

    def __post_init__(self):
-        world_size = dist.get_world_size()
+        world_size = self.parallel_config.world_size
        if self.fast_attn_config.use_fast_attn:
            assert self.parallel_config.dp_degree == world_size, f"world_size must be equal to dp_degree when using DiTFastAttn"

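A note on the world_size plumbing above: because each sub-config now receives world_size explicitly, these checks can run on the Ray driver before torch.distributed is initialized. The snippet below only illustrates the constraint that ParallelConfig.__post_init__ enforces, using the degrees from ray_run.sh (2 GPUs, pipefusion degree 2); the numbers are examples, not output from this PR:

# The product of all parallel degrees must equal the configured world size.
dp_degree, cfg_degree = 1, 1        # data parallel / CFG parallel
ulysses_degree, ring_degree = 1, 1  # sequence parallel
tp_degree, pp_degree = 1, 2         # tensor parallel / pipefusion parallel
world_size = 2                      # --ray_world_size 2

parallel_world_size = (dp_degree * cfg_degree * ulysses_degree
                       * ring_degree * tp_degree * pp_degree)
assert parallel_world_size == world_size           # 1*1*1*1*1*2 == 2
assert world_size % (dp_degree * cfg_degree) == 0  # CFG/data-parallel divisibility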
Empty file added xfuser/ray/pipeline/__init__.py