Make flux work with diffusers 0.30.2 (xdit-project#264)
chengzeyi authored and feifeibear committed Oct 25, 2024
1 parent f7a89be commit 334eea1
Showing 10 changed files with 388 additions and 322 deletions.
6 changes: 5 additions & 1 deletion README.md
@@ -146,12 +146,16 @@ The overview of xDiT is shown as follows.
 
 ```
 pip install xfuser
+# Or optionally, with flash_attn
+pip install "xfuser[flash_attn]"
 ```
 
 ### 2. Install from source
 
 ```
-python setup.py install
+pip install -e .
+# Or optionally, with flash_attn
+pip install -e ".[flash_attn]"
 ```
 
 Note that we use two self-maintained packages:
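Not part of the commit, but a quick post-install sanity check for either path above. A minimal sketch, assuming only that the `xfuser` distribution is installed and that the optional `flash_attn` extra may or may not be present:

```
# Sketch: verify the install and the optional flash_attn extra.
# Uses only the standard library; nothing here is xfuser-specific API.
from importlib.metadata import version, PackageNotFoundError

try:
    print("xfuser", version("xfuser"))
except PackageNotFoundError:
    print("xfuser is not installed")

try:
    import flash_attn  # present only if the [flash_attn] extra was installed
    print("flash_attn", flash_attn.__version__)
except ImportError:
    print("flash_attn not installed (optional)")
```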
42 changes: 17 additions & 25 deletions setup.py
@@ -1,7 +1,5 @@
 from setuptools import find_packages, setup
 import os
-import subprocess
-import sys
 
 def get_cuda_version():
     try:
@@ -12,40 +10,34 @@ def get_cuda_version():
     except Exception as e:
         return 'no_cuda'
 
-def get_install_requires(cuda_version):
-    if cuda_version == 'cu124':
-        sys.stderr.write("WARNING: Manual installation required for CUDA 12.4 specific PyTorch version.\n")
-        sys.stderr.write("Please install PyTorch for CUDA 12.4 using the following command:\n")
-        sys.stderr.write("pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124\n")
-
-    return [
-        "torch==2.3.0",
-        "diffusers>=0.30.0",
-        "transformers>=4.39.1",
-        "sentencepiece>=0.1.99",
-        "accelerate==0.33.0",
-        "beautifulsoup4>=4.12.3",
-        "distvae",
-        "yunchang==0.3",
-        "flash_attn>=2.6.3",
-        "pytest",
-        "flask",
-    ]
-
 if __name__ == "__main__":
     with open("README.md", "r") as f:
         long_description = f.read()
     fp = open("xfuser/__version__.py", "r").read()
     version = eval(fp.strip().split()[-1])
 
-    cuda_version = get_cuda_version()
-
     setup(
         name="xfuser",
         author="xDiT Team",
         author_email="fangjiarui123@gmail.com",
         packages=find_packages(),
-        install_requires=get_install_requires(cuda_version),
+        install_requires=[
+            "torch>=2.3.0",
+            "accelerate==0.33.0",
+            "diffusers==0.30.2",
+            "transformers>=4.39.1",
+            "sentencepiece>=0.1.99",
+            "beautifulsoup4>=4.12.3",
+            "distvae",
+            "yunchang==0.3",
+            "pytest",
+            "flask",
+        ],
+        extras_require={
+            "flash_attn": [
+                "flash_attn>=2.6.3",
+            ],
+        },
         url="https://github.com/xdit-project/xDiT",
         description="xDiT: A Scalable Inference Engine for Diffusion Transformers (DiTs) on multi-GPU Clusters",
         long_description=long_description,
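One retained line worth a gloss: `version = eval(fp.strip().split()[-1])` takes the last whitespace-separated token of `xfuser/__version__.py` and evaluates it. A minimal illustration with a stand-in file body (the literal version string below is made up), plus a regex equivalent that avoids `eval`:

```
import re

# Stand-in for the contents of xfuser/__version__.py (version string made up):
fp = '__version__ = "0.0.0"'

# What setup.py does: take the last token and eval it -> "0.0.0"
version = eval(fp.strip().split()[-1])

# Equivalent extraction without eval-ing file contents:
match = re.search(r'__version__\s*=\s*["\']([^"\']+)["\']', fp)
assert match is not None and match.group(1) == version
```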
13 changes: 6 additions & 7 deletions xfuser/core/distributed/runtime_state.py
@@ -136,7 +136,8 @@ def set_input_parameters(
             self.input_config.seed = seed
             set_random_seed(seed)
         if (
-            (height and self.input_config.height != height)
+            not self.ready
+            or (height and self.input_config.height != height)
             or (width and self.input_config.width != width)
             or (batch_size and self.input_config.batch_size != batch_size)
             or not self.ready
@@ -163,7 +164,8 @@ def set_video_input_parameters(
             self.input_config.seed = seed
             set_random_seed(seed)
         if (
-            (height and self.input_config.height != height)
+            not self.ready
+            or (height and self.input_config.height != height)
             or (width and self.input_config.width != width)
             or (num_frames and self.input_config.num_frames != num_frames)
             or (batch_size and self.input_config.batch_size != batch_size)
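Both hunks make the same fix: `not self.ready` now leads the disjunction, so the input reconfiguration runs on the first call even when the requested dimensions happen to match the stored config. A toy model of the guard (illustrative names, not the class's actual fields):

```
# Toy model of the reordered guard; names are illustrative only.
def needs_reconfigure(ready, cfg, height=None, width=None, batch_size=None):
    return bool(
        not ready
        or (height and cfg["height"] != height)
        or (width and cfg["width"] != width)
        or (batch_size and cfg["batch_size"] != batch_size)
    )

cfg = {"height": 1024, "width": 1024, "batch_size": 1}
assert needs_reconfigure(False, cfg, height=1024)     # first call always fires
assert not needs_reconfigure(True, cfg, height=1024)  # unchanged size: skip
assert needs_reconfigure(True, cfg, height=512)       # changed size: fires
```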
@@ -361,7 +363,6 @@ def _calc_patches_metadata(self):
         self.pp_patches_token_num = pp_patches_token_num
 
     def _calc_cogvideox_patches_metadata(self):
-
         num_sp_patches = get_sequence_parallel_world_size()
         sp_patch_idx = get_sequence_parallel_rank()
         patch_size = self.backbone_patch_size
@@ -450,11 +451,9 @@ def _calc_cogvideox_patches_metadata(self):
         pp_patches_token_start_end_idx_global = [
             [
                 (latents_width // patch_size)
-                * (start_idx // patch_size)
-                * latents_frames,
+                * (start_idx // patch_size),
                 (latents_width // patch_size)
-                * (end_idx // patch_size)
-                * latents_frames,
+                * (end_idx // patch_size),
             ]
             for start_idx, end_idx in pp_patches_start_end_idx_global
         ]
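The dropped `* latents_frames` factor had scaled the global token start/end indices by the frame count; they are now per-frame token offsets. A toy calculation with made-up sizes shows the magnitude of the old overshoot:

```
# Made-up sizes for illustration only.
patch_size, latents_width, latents_frames = 2, 16, 9
start_idx, end_idx = 0, 8              # latent rows assigned to this patch

tokens_per_row = latents_width // patch_size              # 8
start_token = tokens_per_row * (start_idx // patch_size)  # 0  (new formula)
end_token = tokens_per_row * (end_idx // patch_size)      # 32 (new formula)

# The old formula multiplied by latents_frames as well:
assert end_token * latents_frames == 288  # 9x past the per-frame range
```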
6 changes: 3 additions & 3 deletions xfuser/core/long_ctx_attention/hybrid/attn_layer.py
@@ -228,16 +228,16 @@ def forward(
         key: Tensor,
         value: Tensor,
         *,
-        joint_tensor_query,
-        joint_tensor_key,
-        joint_tensor_value,
         dropout_p=0.0,
         softmax_scale=None,
         causal=False,
         window_size=(-1, -1),
         alibi_slopes=None,
         deterministic=False,
         return_attn_probs=False,
+        joint_tensor_query=None,
+        joint_tensor_key=None,
+        joint_tensor_value=None,
         joint_strategy="front",
     ) -> Tensor:
         """forward
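The `joint_tensor_*` arguments stay keyword-only but gain `None` defaults and move behind the flash-attention-style options, so callers without joint tensors can omit them. A stripped-down sketch of the new signature shape (not the real class; body elided):

```
# Signature sketch only; the real method lives on the hybrid attention layer.
def forward(query, key, value, *, dropout_p=0.0, softmax_scale=None,
            causal=False, joint_tensor_query=None, joint_tensor_key=None,
            joint_tensor_value=None, joint_strategy="front"):
    if joint_tensor_query is None:
        ...  # plain attention path, now reachable without joint tensors
    ...

forward("q", "k", "v")                           # legal after this commit
forward("q", "k", "v", joint_tensor_query="jq")  # joint path unchanged
```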
(Diffs for the remaining six of the 10 changed files were not loaded on this page.)
