Fix parallel vae #281

Merged 6 commits on Sep 26, 2024
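This PR fixes image output when parallel VAE is enabled. It adds two helpers to xFuserPipelineBaseWrapper: is_dp_last_group(), which treats world rank 0 as the last group when runtime_config.use_parallel_vae is set, and gather_broadcast_latents(), which gathers latents from the last data-parallel group and broadcasts the full batch to every rank. The example scripts switch from the free function is_dp_last_group() to pipe.is_dp_last_group(), and the Flux, HunyuanDiT, PixArt-alpha, PixArt-sigma, and SD3 pipelines route decoding through a per-pipeline vae_decode closure. It also makes a few incidental example tweaks (bfloat16 dtype, prepare_run(input_config, steps=1), clean_caption=False) and only warns about a joint_attention_kwargs "scale" when one is actually passed. A condensed sketch of the dispatch each pipeline now uses (names are taken from the diffs below; output_type, latents, and self come from each pipeline's __call__):

if not output_type == "latent":
    if get_runtime_state().runtime_config.use_parallel_vae:
        # Parallel VAE: every rank receives the full latent batch and joins the decode.
        latents = self.gather_broadcast_latents(latents)
        image = vae_decode(latents)
    else:
        # Default path: only the last data-parallel group decodes.
        if is_dp_last_group():
            image = vae_decode(latents)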
5 changes: 3 additions & 2 deletions examples/flux_example.py
@@ -18,6 +18,7 @@ def main():
args = xFuserArgs.add_cli_args(parser).parse_args()
engine_args = xFuserArgs.from_cli_args(args)
engine_config, input_config = engine_args.create_config()
engine_config.runtime_config.dtype = torch.bfloat16
local_rank = get_world_group().local_rank

pipe = xFuserFluxPipeline.from_pretrained(
@@ -32,7 +33,7 @@ def main():
else:
pipe = pipe.to(f"cuda:{local_rank}")

pipe.prepare_run(input_config)
pipe.prepare_run(input_config, steps=1)

torch.cuda.reset_peak_memory_stats()
start_time = time.time()
@@ -60,7 +61,7 @@ def main():
dp_group_index = get_data_parallel_rank()
num_dp_groups = get_data_parallel_world_size()
dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
if is_dp_last_group():
if pipe.is_dp_last_group():
for i, image in enumerate(output.images):
image_rank = dp_group_index * dp_batch_size + i
image_name = f"flux_result_{parallel_info}_{image_rank}_tc_{engine_args.use_torch_compile}.png"
2 changes: 1 addition & 1 deletion examples/hunyuandit_example.py
@@ -50,7 +50,7 @@ def main():
dp_group_index = get_data_parallel_rank()
num_dp_groups = get_data_parallel_world_size()
dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
if is_dp_last_group():
if pipe.is_dp_last_group():
if not os.path.exists("results"):
os.mkdir("results")
for i, image in enumerate(output.images):
2 changes: 1 addition & 1 deletion examples/pixartalpha_example.py
@@ -50,7 +50,7 @@ def main():
dp_group_index = get_data_parallel_rank()
num_dp_groups = get_data_parallel_world_size()
dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
if is_dp_last_group():
if pipe.is_dp_last_group():
if not os.path.exists("results"):
os.mkdir("results")
for i, image in enumerate(output.images):
3 changes: 2 additions & 1 deletion examples/pixartsigma_example.py
@@ -36,6 +36,7 @@ def main():
output_type=input_config.output_type,
use_resolution_binning=input_config.use_resolution_binning,
generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
clean_caption=False,
)
end_time = time.time()
elapsed_time = end_time - start_time
@@ -50,7 +51,7 @@ def main():
dp_group_index = get_data_parallel_rank()
num_dp_groups = get_data_parallel_world_size()
dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
if is_dp_last_group():
if pipe.is_dp_last_group():
if not os.path.exists("results"):
os.mkdir("results")
for i, image in enumerate(output.images):
2 changes: 1 addition & 1 deletion examples/sd3_example.py
@@ -49,7 +49,7 @@ def main():
dp_group_index = get_data_parallel_rank()
num_dp_groups = get_data_parallel_world_size()
dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
if is_dp_last_group():
if pipe.is_dp_last_group():
if not os.path.exists("results"):
os.mkdir("results")
for i, image in enumerate(output.images):
(file path not shown for this hunk)
@@ -80,7 +80,7 @@ def forward(
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
elif joint_attention_kwargs and "scale" in joint_attention_kwargs:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
65 changes: 65 additions & 0 deletions xfuser/model_executor/pipelines/base_pipeline.py
@@ -27,6 +27,8 @@
get_world_group,
get_runtime_state,
initialize_runtime_state,
is_dp_last_group,
get_sequence_parallel_rank,
)
from xfuser.model_executor.base_wrapper import xFuserBaseWrapper

@@ -392,3 +394,66 @@ def _process_cfg_split_batch_latte(
else:
raise ValueError("Invalid classifier free guidance rank")
return concat_group_0

def is_dp_last_group(self):
"""Return True if in the last data parallel group, False otherwise.
Also include parallel vae situation.
"""
if get_runtime_state().runtime_config.use_parallel_vae:
return get_world_group().rank == 0
else:
return is_dp_last_group()

def gather_broadcast_latents(self, latents: torch.Tensor):
"""Gather latents from the dp last group and broadcast the final latents to all ranks."""

# ---------gather latents from dp last group-----------
rank = get_world_group().rank
device = f"cuda:{rank}"

# all gather dp last group rank list
dp_rank_list = [torch.zeros(1, dtype=int, device=device) for _ in range(get_world_group().world_size)]
if is_dp_last_group():
gather_rank = int(rank)
else:
gather_rank = -1
torch.distributed.all_gather(dp_rank_list, torch.tensor([gather_rank],dtype=int,device=device))

dp_rank_list = [int(dp_rank[0]) for dp_rank in dp_rank_list if int(dp_rank[0])!=-1]
dp_last_group = torch.distributed.new_group(dp_rank_list)

# gather latents from dp last group
if rank == dp_rank_list[-1]:
latents_list = [torch.zeros_like(latents) for _ in dp_rank_list]
else:
latents_list = None
if rank in dp_rank_list:
torch.distributed.gather(latents, latents_list, dst=dp_rank_list[-1], group=dp_last_group)

if rank == dp_rank_list[-1]:
latents = torch.cat(latents_list,dim=0)

# ------broadcast latents to all nodes---------
src = dp_rank_list[-1]
latents_shape_len = torch.zeros(1,dtype=torch.int,device=device)

# broadcast latents shape len
if rank == src:
latents_shape_len[0] = len(latents.shape)
get_world_group().broadcast(latents_shape_len,src=src)

# broadcast latents shape
if rank == src:
input_shape = torch.tensor(latents.shape,dtype=torch.int,device=device)
else:
input_shape = torch.zeros(latents_shape_len[0],dtype=torch.int,device=device)
get_world_group().broadcast(input_shape,src=src)

# broadcast latents
if rank != src:
dtype = get_runtime_state().runtime_config.dtype
latents = torch.zeros(torch.Size(input_shape),dtype=dtype,device=device)
get_world_group().broadcast(latents,src=src)

return latents
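
For reference, a hedged sketch of how these helpers are consumed, distilled from the example and pipeline diffs in this PR; pipe and input_config are set up as in the example scripts, and save_results is a hypothetical stand-in for the image-saving loop in the examples:

output = pipe(
    prompt=input_config.prompt,
    num_inference_steps=input_config.num_inference_steps,
    output_type=input_config.output_type,
    generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
)

# With use_parallel_vae, the pipeline broadcasts the full latents to every rank
# via gather_broadcast_latents() before decoding, and is_dp_last_group() resolves
# to world rank 0, so exactly one rank writes the results; without it, the last
# data-parallel group does.
if pipe.is_dp_last_group():
    save_results(output.images)  # hypothetical helper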
30 changes: 21 additions & 9 deletions xfuser/model_executor/pipelines/pipeline_flux.py
@@ -32,6 +32,7 @@
is_pipeline_first_stage,
is_pipeline_last_stage,
is_dp_last_group,
get_world_group
)
from .base_pipeline import xFuserPipelineBaseWrapper
from .register import xFuserPipelineWrapperRegister
@@ -309,19 +310,30 @@ def __call__(
sync_only=True,
)

if is_dp_last_group():
def vae_decode(latents):
latents = self._unpack_latents(
latents, height, width, self.vae_scale_factor
)
latents = (
latents / self.vae.config.scaling_factor
) + self.vae.config.shift_factor

image = self.vae.decode(latents, return_dict=False)[0]
return image

if not output_type == "latent":
if get_runtime_state().runtime_config.use_parallel_vae:
latents = self.gather_broadcast_latents(latents)
image = vae_decode(latents)
else:
if is_dp_last_group():
image = vae_decode(latents)

if self.is_dp_last_group():
if output_type == "latent":
image = latents

else:
latents = self._unpack_latents(
latents, height, width, self.vae_scale_factor
)
latents = (
latents / self.vae.config.scaling_factor
) + self.vae.config.shift_factor

image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)

# Offload all models
19 changes: 15 additions & 4 deletions xfuser/model_executor/pipelines/pipeline_hunyuandit.py
@@ -28,6 +28,7 @@
is_dp_last_group,
is_pipeline_last_stage,
is_pipeline_first_stage,
get_world_group
)
from xfuser.model_executor.pipelines import xFuserPipelineBaseWrapper
from .register import xFuserPipelineWrapperRegister
@@ -454,12 +455,22 @@ def __call__(

# 8. Decode latents (only rank 0)
#! ---------------------------------------- ADD BELOW ----------------------------------------
if is_dp_last_group():
def vae_decode(latents):
image = self.vae.decode(
latents / self.vae.config.scaling_factor, return_dict=False
)[0]
return image

if not output_type == "latent":
if get_runtime_state().runtime_config.use_parallel_vae:
latents = self.gather_broadcast_latents(latents)
image = vae_decode(latents)
else:
if is_dp_last_group():
image = vae_decode(latents)
if self.is_dp_last_group():
#! ---------------------------------------- ADD ABOVE ----------------------------------------
if not output_type == "latent":
image = self.vae.decode(
latents / self.vae.config.scaling_factor, return_dict=False
)[0]
image, has_nsfw_concept = self.run_safety_checker(
image, device, prompt_embeds.dtype
)
19 changes: 15 additions & 4 deletions xfuser/model_executor/pipelines/pipeline_pixart_alpha.py
@@ -25,6 +25,7 @@
get_sp_group,
is_pipeline_first_stage,
is_pipeline_last_stage,
get_world_group
)
from xfuser.model_executor.pipelines import xFuserPipelineBaseWrapper
from .register import xFuserPipelineWrapperRegister
@@ -359,12 +360,22 @@ def __call__(

# 8. Decode latents (only rank 0)
#! ---------------------------------------- ADD BELOW ----------------------------------------
if is_dp_last_group():
def vae_decode(latents):
image = self.vae.decode(
latents / self.vae.config.scaling_factor, return_dict=False
)[0]
return image

if not output_type == "latent":
if get_runtime_state().runtime_config.use_parallel_vae:
latents = self.gather_broadcast_latents(latents)
image = vae_decode(latents)
else:
if is_dp_last_group():
image = vae_decode(latents)
if self.is_dp_last_group():
#! ---------------------------------------- ADD ABOVE ----------------------------------------
if not output_type == "latent":
image = self.vae.decode(
latents / self.vae.config.scaling_factor, return_dict=False
)[0]
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(
image, orig_width, orig_height
21 changes: 17 additions & 4 deletions xfuser/model_executor/pipelines/pipeline_pixart_sigma.py
@@ -25,6 +25,7 @@
get_sp_group,
is_pipeline_first_stage,
is_pipeline_last_stage,
get_world_group
)
from .base_pipeline import xFuserPipelineBaseWrapper
from .register import xFuserPipelineWrapperRegister
@@ -323,11 +324,23 @@ def __call__(
)

# * 8. Decode latents (only the last rank in a dp group)
if is_dp_last_group():

def vae_decode(latents):
image = self.vae.decode(
latents / self.vae.config.scaling_factor, return_dict=False
)[0]
return image

if not output_type == "latent":
if get_runtime_state().runtime_config.use_parallel_vae:
latents = self.gather_broadcast_latents(latents)
image = vae_decode(latents)
else:
if is_dp_last_group():
image = vae_decode(latents)

if self.is_dp_last_group():
if not output_type == "latent":
image = self.vae.decode(
latents / self.vae.config.scaling_factor, return_dict=False
)[0]
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(
image, orig_width, orig_height
25 changes: 18 additions & 7 deletions xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py
@@ -38,6 +38,7 @@
get_sequence_parallel_world_size,
get_sp_group,
is_dp_last_group,
get_world_group
)
from .base_pipeline import xFuserPipelineBaseWrapper
from .register import xFuserPipelineWrapperRegister
@@ -379,16 +380,26 @@ def __call__(
)

# * 8. Decode latents (only the last rank in a dp group)
if is_dp_last_group():

def vae_decode(latents):
latents = (
latents / self.vae.config.scaling_factor
) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
return image

if not output_type == "latent":
if get_runtime_state().runtime_config.use_parallel_vae:
latents = self.gather_broadcast_latents(latents)
image = vae_decode(latents)
else:
if is_dp_last_group():
image = vae_decode(latents)

if self.is_dp_last_group():
if output_type == "latent":
image = latents

else:
latents = (
latents / self.vae.config.scaling_factor
) + self.vae.config.shift_factor

image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)

# Offload all models