Fix CogVideoX support #261

Merged: 11 commits, Sep 12, 2024
Conversation

@chengzeyi (Collaborator) commented Sep 10, 2024

  1. Make the flash_attn dependency optional when installing (see the sketch after this list).
  2. Do not force a fixed torch version in setup.py.
  3. Add xFuserCogVideoXAttnProcessor2_0 to support CogVideoXAttnProcessor2_0 in newer versions of diffusers.
  4. Support the latest diffusers (use a bundled apply_rotary_emb).
  5. CogVideoX now supports Ulysses sequence parallelism.
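
A minimal sketch of how the optional dependency can be wired up; the extras name matches the discussion below, but the runtime guard is illustrative rather than the PR's exact code:

# setup.py: flash_attn moves out of install_requires into an extra, so a plain
# `pip install xfuser` works without it.
extras_require = {
    "flash_attn": ["flash_attn>=2.6.3"],
}

# Runtime guard: degrade gracefully when flash_attn is absent.
try:
    import flash_attn  # noqa: F401
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False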

Current status

torchrun --nproc_per_node=2 examples/cogvideox_example.py --ulysses_degree 2 \
    --model THUDM/CogVideoX-5b --height 480 --width 720 --num_frames 30 \
    --prompt "a panda playing piano"

Works perfectly now.

    ],
    extras_require={
        "all": [
            "flash_attn>=2.6.3",
Collaborator:
We should update the readme: pip install xfuser[flash_attn]?

@chengzeyi (Collaborator, Author):

OK, I changed the extra from all to flash_attn and modified the readme.
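
For reference, the resulting install commands (assuming the flash_attn extra named above):

pip install xfuser                 # base install, no flash_attn
pip install "xfuser[flash_attn]"   # opt in to the flash_attn extra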

@feifeibear merged commit 9484590 into xdit-project:main Sep 12, 2024
@@ -12,40 +12,34 @@ def get_cuda_version():
     except Exception as e:
         return 'no_cuda'

-def get_install_requires(cuda_version):
Collaborator:

This function was introduced in #259; why remove it?

@chengzeyi (Collaborator, Author) commented Sep 12, 2024:

This package (xdit) does not require a specific CUDA version to build or run, so the warning is meaningless.

@chengzeyi (Collaborator, Author):

And #259 itself is buggy: it intends to print a warning when the version is not 12.4, but it actually prints the warning when the version is equal to 12.4... 😢
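
A hypothetical illustration of the inverted check (the condition string and message are assumptions, not the exact code from #259):

cuda_version = get_cuda_version()

# buggy: warns when the version IS 12.4, i.e. the supported one
if cuda_version == '12.4':
    print("WARNING: CUDA 12.4 is recommended")

# intended: warn when the version is NOT 12.4
if cuda_version != '12.4':
    print("WARNING: CUDA 12.4 is recommended")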

@chengzeyi (Collaborator, Author):

The way it checks the CUDA version is also incorrect: it checks the version of the system CUDA rather than the CUDA bundled with the PyTorch installation. So the get_cuda_version func could be removed in a future commit.
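
For reference, the CUDA version that matters at runtime is the one PyTorch was built against, which PyTorch reports directly:

import torch

# CUDA version bundled with this PyTorch build (None for CPU-only builds),
# unlike `nvcc --version`, which reports the system toolkit instead.
print(torch.version.cuda)  # e.g. "12.4"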

def torch_compile_disable_if_v100(func):
    # Exclude func from torch.compile when running on a V100.
    if is_v100():
        return torch.compiler.disable(func)
    return func


def apply_rotary_emb(
Collaborator:

This change results in the following error:

torchrun --nproc_per_node=2 ./examples/flux_example.py --model /cfs/dit/FLUX.1-dev --ulysses_degree 2 --prompt "A snowy mountain" --num_inference_steps 20

[rank0]: Traceback (most recent call last):
[rank0]:   File "~/xDiT/./examples/flux_example.py", line 77, in <module>
[rank0]:     main()
[rank0]:   File "~/xDiT/./examples/flux_example.py", line 35, in main
[rank0]:     pipe.prepare_run(input_config)
[rank0]:   File "~/xDiT/xfuser/model_executor/pipelines/pipeline_flux.py", line 69, in prepare_run
[rank0]:     self.__call__(
[rank0]:   File "~/miniconda3/envs/long_ctx_attn/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "~/xDiT/xfuser/model_executor/pipelines/base_pipeline.py", line 181, in wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "~/xDiT/xfuser/model_executor/pipelines/base_pipeline.py", line 133, in data_parallel_fn
[rank0]:     return func(self, *args, **kwargs)
[rank0]:   File "~/xDiT/xfuser/model_executor/pipelines/base_pipeline.py", line 149, in check_naive_forward_fn
[rank0]:     return func(self, *args, **kwargs)
[rank0]:   File "~/xDiT/xfuser/model_executor/pipelines/pipeline_flux.py", line 297, in __call__
[rank0]:     latents = self._sync_pipeline(
[rank0]:   File "~/xDiT/xfuser/model_executor/pipelines/pipeline_flux.py", line 399, in _sync_pipeline
[rank0]:     latents, encoder_hidden_states = self._backbone_forward(
[rank0]:   File "~/xDiT/xfuser/model_executor/pipelines/pipeline_flux.py", line 484, in _backbone_forward
[rank0]:     noise_pred, encoder_hidden_states = self.transformer(
[rank0]:   File "~/miniconda3/envs/long_ctx_attn/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "~/miniconda3/envs/long_ctx_attn/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "~/xDiT/xfuser/model_executor/models/transformers/transformer_flux.py", line 147, in forward
[rank0]:     encoder_hidden_states, hidden_states = block(
[rank0]:   File "~/miniconda3/envs/long_ctx_attn/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "~/miniconda3/envs/long_ctx_attn/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "~/miniconda3/envs/long_ctx_attn/lib/python3.10/site-packages/diffusers/models/transformers/transformer_flux.py", line 200, in forward
[rank0]:     attn_output, context_attn_output = self.attn(
[rank0]:   File "~/miniconda3/envs/long_ctx_attn/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "~/miniconda3/envs/long_ctx_attn/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "~/xDiT/xfuser/model_executor/layers/attention_processor.py", line 223, in forward
[rank0]:     return self.processor(
[rank0]:   File "~/xDiT/xfuser/model_executor/layers/attention_processor.py", line 705, in __call__
[rank0]:     query = apply_rotary_emb(query, image_rotary_emb)
[rank0]:   File "~/xDiT/xfuser/model_executor/layers/attention_processor.py", line 76, in apply_rotary_emb
[rank0]:     cos, sin = freqs_cis  # [S, D]
[rank0]: ValueError: not enough values to unpack (expected 2, got 1)

We can merge this PR again after the problem is fixed.

@chengzeyi (Collaborator, Author):

Oh, it looks like the latest diffusers has changed a lot and is no longer compatible with the Flux implementation in xdit.
I guess a refactor may be needed.
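
A minimal sketch of the kind of defensive unpacking that would avoid the crash in the traceback above; this is an assumption about the format mismatch, not the fix that was actually applied, and it mirrors the bundled helper only loosely:

import torch

def apply_rotary_emb(x, freqs_cis):
    # diffusers has changed the rotary-embedding layout across releases;
    # tolerate both a (cos, sin) pair and a single stacked tensor.
    if torch.is_tensor(freqs_cis):
        cos, sin = freqs_cis.unbind(0)   # assume the pair is stacked on dim 0
    else:
        cos, sin = freqs_cis             # (cos, sin), each [S, D]
    x1, x2 = x[..., 0::2], x[..., 1::2]  # split interleaved real/imag parts
    rotated = torch.stack((-x2, x1), dim=-1).flatten(-2)
    return x * cos + rotated * sin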

@chengzeyi (Collaborator, Author) commented Sep 12, 2024:

After modifying the attention processor and removing the reuse of encoder_hidden_states, it now works with diffusers==0.30.2.
Anyway, it is still only a reference for you. I guess the frequent changes in diffusers, plus the lack of an automatic code-patching mechanism, could make this challenging to keep up with, so these changes need to be treated carefully. 🙂
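
One illustrative way to cope with that churn (an assumption on my part, not code from this PR) is to gate the processor selection on the installed diffusers version:

import diffusers
from packaging import version

# Pick the attention-processor code path based on the installed diffusers release.
USE_NEW_PROCESSOR = version.parse(diffusers.__version__) >= version.parse("0.30.2")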

#264

Eigensystem added a commit to Eigensystem/xDiT that referenced this pull request Sep 12, 2024
Eigensystem added a commit that referenced this pull request Sep 12, 2024
feifeibear pushed a commit to feifeibear/xDiT that referenced this pull request Oct 25, 2024