diff --git a/docs/performance/flux.md b/docs/performance/flux.md
index e93e8e5..b906b3a 100644
--- a/docs/performance/flux.md
+++ b/docs/performance/flux.md
@@ -17,8 +17,10 @@ Since Flux.1 does not utilize Classifier-Free Guidance (CFG), it is not compatib
We conducted performance benchmarking using FLUX.1 [dev] with 28 diffusion steps.
The following figure shows the scalability of Flux.1 on two 8xL40 nodes (16 L40 GPUs in total).
-Consequently, the performance improvement dose not achieved with 16 GPUs, and for 1024px and 2048px tasks.
+Although CFG parallel is not available, we can still achieve enhanced scalability by using PipeFusion as the parallel method between nodes.
+For the 1024px task, hybrid parallelism on 16xL40s achieves 1.16x lower latency than on 8xL40s, with the best configuration being ulysses=4 and pipefusion=4.
For the 4096px task, hybrid parallelism still benefits from 16 L40s, with 1.9x lower latency than on 8 GPUs, where the configuration is ulysses=2, ring=2, and pipefusion=4.
+For the 2048px task, however, no performance improvement is achieved with 16 GPUs.
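+
+As a quick sanity check when choosing a hybrid configuration, the parallel degrees must multiply to the total GPU count. The helper below is an illustrative sketch, not part of xDiT's API:
+
+```python
+from itertools import product
+
+def hybrid_configs(world_size: int, max_pipefusion: int = 8):
+    """Yield (ulysses, ring, pipefusion) degrees whose product is world_size."""
+    for u, r, p in product(range(1, world_size + 1), repeat=3):
+        if u * r * p == world_size and p <= max_pipefusion:
+            yield {"ulysses": u, "ring": r, "pipefusion": p}
+
+# For 16 L40s, the candidates include ulysses=4, ring=1, pipefusion=4 (best at 1024px)
+# and ulysses=2, ring=2, pipefusion=4 (best at 4096px).
+for cfg in hybrid_configs(16):
+    print(cfg)
+```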
+
+The following figure shows the scalability of Flux.1 on 8xA100 GPUs.
+For the 1024px and 2048px image generation tasks, SP-Ulysses exhibits the lowest latency among the single parallel methods. In this case, the optimal hybrid strategy is also SP-Ulysses.
+
+Note that the latency shown in the figure above does not yet include the use of torch.compile, which would bring further performance improvements.
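+
+To reproduce the torch.compile speedup, the usual diffusers pattern is to compile the transformer once before timing. A minimal sketch, assuming FLUX.1 [dev] is loaded with the standard diffusers FluxPipeline rather than through xDiT's launcher:
+
+```python
+import torch
+from diffusers import FluxPipeline
+
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+).to("cuda")
+
+# One-time compilation cost, amortized over all subsequent calls.
+pipe.transformer = torch.compile(pipe.transformer)
+
+# The first call triggers compilation; benchmark the later calls.
+image = pipe("a cat wearing sunglasses", num_inference_steps=28).images[0]
+```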
+
+### Scalability of Flux.1 Schnell
+We conducted performance benchmarking using FLUX.1 [schnell] with 4 diffusion steps.
+Since there are so few diffusion steps, we do not use PipeFusion.
On an 8xA100 (80GB) NVLink machine, the best USP strategy for 1024px generation is to allocate all the parallel degrees to Ulysses; with torch.compile, generating a 1024px image takes only 0.82 seconds!
@@ -54,7 +75,7 @@ xDiT does not yet support PipeFusion for Flux.1, because the schnell version has too few sampling steps
alt="latency-flux_l40_2k">
-### VAE Parallel
+### VAE Parallel
On an A100, a single GPU OOMs when running Flux.1 beyond 2048px. This is caused jointly by the increased activation memory demand and the memory spikes triggered by the convolution operators.
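+
+For comparison, the generic single-GPU mitigation for this spike is diffusers' tiled VAE decoding, which decodes the latent tile by tile instead of as one large tensor. The sketch below uses that upstream fallback, not xDiT's parallel VAE:
+
+```python
+import torch
+from diffusers import FluxPipeline
+
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+).to("cuda")
+
+# Cap the VAE's activation peak at high resolutions by decoding in tiles.
+pipe.vae.enable_tiling()
+
+image = pipe("a castle above the clouds", height=2048, width=2048).images[0]
+```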
@@ -68,3 +89,4 @@ The prompt is "A hyperrealistic portrait of a weathered sailor in his 60s, with deep-
+
diff --git a/setup.py b/setup.py
index 452ce12..beb2c20 100644
--- a/setup.py
+++ b/setup.py
@@ -37,6 +37,8 @@ def get_cuda_version():
"pytest",
"flask",
"opencv-python",
+ "imageio",
+ "imageio-ffmpeg",
],
extras_require={
"flash_attn": [
diff --git a/xfuser/model_executor/pipelines/pipeline_cogvideox.py b/xfuser/model_executor/pipelines/pipeline_cogvideox.py
index 5d04352..d686b43 100644
--- a/xfuser/model_executor/pipelines/pipeline_cogvideox.py
+++ b/xfuser/model_executor/pipelines/pipeline_cogvideox.py
@@ -226,7 +226,9 @@ def __call__(
max_sequence_length=max_sequence_length,
device=device,
)
- prompt_embeds = self._process_cfg_split_batch_latte(prompt_embeds, negative_prompt_embeds)
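+        # Split the conditional / unconditional prompt embeddings across
+        # classifier-free-guidance parallel ranks.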
+ prompt_embeds = self._process_cfg_split_batch_latte(
+ prompt_embeds, negative_prompt_embeds
+ )
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
@@ -253,7 +255,9 @@ def __call__(
# 7. Create rotary embeds if required
image_rotary_emb = (
- self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+ self._prepare_rotary_positional_embeddings(
+ height, width, latents.size(1), device
+ )
if self.transformer.config.use_rotary_positional_embeddings
else None
)
@@ -263,7 +267,9 @@ def __call__(
len(timesteps) - num_inference_steps * self.scheduler.order, 0
)
- latents, image_rotary_emb = self._init_sync_pipeline(latents, image_rotary_emb, latents.size(1))
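+        # Prepare latents and rotary embeddings for the patch-parallel sync
+        # pipeline before the denoising loop starts.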
+ latents, image_rotary_emb = self._init_sync_pipeline(
+ latents, image_rotary_emb, latents.size(1)
+ )
with self.progress_bar(total=num_inference_steps) as progress_bar:
# for DPM-solver++
old_pred_original_sample = None
@@ -296,7 +302,18 @@ def __call__(
# perform guidance
if use_dynamic_cfg:
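+                    # Dynamic CFG: modulate the guidance scale along the sampling
+                    # trajectory with a cosine schedule, as in the upstream
+                    # diffusers CogVideoX pipeline.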
self._guidance_scale = 1 + guidance_scale * (
- (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+ (
+ 1
+ - math.cos(
+ math.pi
+ * (
+ (num_inference_steps - t.item())
+ / num_inference_steps
+ )
+ ** 5.0
+ )
+ )
+ / 2
)
if do_classifier_free_guidance:
if get_classifier_free_guidance_world_size() == 1:
@@ -339,7 +356,9 @@ def __call__(
"negative_prompt_embeds", negative_prompt_embeds
)
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ if i == len(timesteps) - 1 or (
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+ ):
progress_bar.update()
if get_sequence_parallel_world_size() > 1:
@@ -377,14 +396,22 @@ def _init_sync_pipeline(
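+        # Keep, for each pipeline patch, only the rows of the rotary cos/sin
+        # tables that fall within that patch's token range.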
image_rotary_emb = (
torch.cat(
[
- image_rotary_emb[0].reshape(latents_frames, -1, d)[:, start_token_idx:end_token_idx].reshape(-1, d)
+ image_rotary_emb[0]
+ .reshape(latents_frames, -1, d)[
+ :, start_token_idx:end_token_idx
+ ]
+ .reshape(-1, d)
for start_token_idx, end_token_idx in get_runtime_state().pp_patches_token_start_end_idx_global
],
dim=0,
),
torch.cat(
[
- image_rotary_emb[1].reshape(latents_frames, -1, d)[:, start_token_idx:end_token_idx].reshape(-1, d)
+ image_rotary_emb[1]
+ .reshape(latents_frames, -1, d)[
+ :, start_token_idx:end_token_idx
+ ]
+ .reshape(-1, d)
for start_token_idx, end_token_idx in get_runtime_state().pp_patches_token_start_end_idx_global
],
dim=0,