From 70b156d9d4f9aadb980f1d5281694992f6cb8d03 Mon Sep 17 00:00:00 2001 From: Xibo Sun Date: Thu, 26 Dec 2024 23:17:40 +0800 Subject: [PATCH] Parallel flux on diffusers version 0.32 (#413) --- xfuser/core/distributed/runtime_state.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/xfuser/core/distributed/runtime_state.py b/xfuser/core/distributed/runtime_state.py index 3a69728..929ed82 100644 --- a/xfuser/core/distributed/runtime_state.py +++ b/xfuser/core/distributed/runtime_state.py @@ -4,6 +4,7 @@ import numpy as np import torch +import diffusers from diffusers import DiffusionPipeline import torch.distributed @@ -121,8 +122,11 @@ def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig): * pipeline.transformer.config.attention_head_dim, ) else: + vae_scale_factor = pipeline.vae_scale_factor + if pipeline.__class__.__name__.startswith("Flux") and tuple(int(p) for p in diffusers.__version__.split('.')[:2]) >= (0, 32): + vae_scale_factor *= 2 self._set_model_parameters( - vae_scale_factor=pipeline.vae_scale_factor, + vae_scale_factor=vae_scale_factor, backbone_patch_size=pipeline.transformer.config.patch_size, backbone_in_channel=pipeline.transformer.config.in_channels, backbone_inner_dim=pipeline.transformer.config.num_attention_heads