
Question about varlen ring attention #6

Closed
TechxGenus opened this issue Jan 15, 2025 · 6 comments

@TechxGenus

Thanks for sharing this great resource.
I am studying the varlen form of ring attention mentioned in the paper, which does avoid the padding problem in the TE implementation. However, it seems to me that it introduces load-imbalance problems, where some CP ranks may require much more computation than others. How do you handle this problem?

@MiniMax-AI-Dev
Contributor

Yes, we have noticed that load imbalance can occur in Varlen Ring Attention. However, this issue is not specific to Varlen Ring Attention alone but is rather a result of the "data-packing + varlen" approach. Even when ring attention is not used, this approach can lead to load imbalance in data parallelism (DP), where some DP ranks may end up with concatenations of short sequences while others hold complete long sequences. This forces the short-sequence DP ranks to wait during synchronization.

In the case of Varlen Ring Attention, this impact extends to the synchronization communication of context parallelism (CP). To address this problem, it is necessary to avoid mixing long and short sequences within the same micro-batch training process. Theoretically, if needed, one could manually adjust the training order of samples with different sequence lengths within the global batch to prevent load imbalance. However, in practice, this adjustment can be challenging because the total number of tokens in the global batch is fixed. In scenarios with long sequences, the number of samples is very small, leaving little room to adjust the load. Therefore, solving this issue requires collaboration with the data side as well.
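A minimal sketch (not the MiniMax implementation) of why "data-packing + varlen" causes this imbalance: causal-attention compute on a rank scales with the sum of squared lengths of the sequences it holds, so a rank packed with one long sequence does far more work than a rank packed with many short ones, even when both hold the same number of tokens.

```python
def causal_attn_cost(seq_lens):
    # Causal attention over a sequence of length L touches ~L*(L+1)/2 query-key pairs.
    return sum(L * (L + 1) // 2 for L in seq_lens)

# Two hypothetical DP ranks, each packed to 8192 tokens.
rank0 = [8192]       # one long sequence
rank1 = [512] * 16   # sixteen short sequences

print(causal_attn_cost(rank0))  # ~33.6M pair interactions
print(causal_attn_cost(rank1))  # ~2.1M  -> rank1 waits on rank0 at sync
```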

@TechxGenus
Author

Thanks, it does help to unify the distribution of sequence lengths across mini-batches.
However, my question is about the imbalance caused by the causal mask when training on longer sequences, which is a separate issue with the ring attention mechanism (ref: zhuzilin/ring-flash-attention#2) and becomes difficult to handle when combined with packing.
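To make the causal-mask imbalance concrete, here is a rough sketch (assumed setup, not from this repo) of a naive contiguous split of one long sequence across CP ranks: rank r's queries attend to all keys held by ranks 0..r, so later ranks do far more work.

```python
def naive_split_cost(seq_len, cp_size):
    chunk = seq_len // cp_size
    costs = []
    for r in range(cp_size):
        # Queries on rank r occupy positions [r*chunk, (r+1)*chunk).
        # Under the causal mask, position q attends to q+1 keys in total.
        q_lo = r * chunk
        costs.append(sum(q + 1 for q in range(q_lo, q_lo + chunk)))
    return costs

print(naive_split_cost(8192, 4))
# [2098176, 6292480, 10486784, 14681088] -> the last rank does ~7x the work of the first
```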

@MiniMax-AI-Dev
Contributor

In this context, the implementation approach we are referring to is the Zig-Zag method, which is also used in TransformerEngine.

Implementing this method in the context of data-packing is indeed troublesome.

[image: illustration of the Zig-Zag partitioning]
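A minimal sketch of the Zig-Zag assignment (as used in ring-flash-attention and TransformerEngine): split the sequence into 2 * cp_size chunks and give rank r chunk r plus chunk (2 * cp_size - 1 - r), pairing an "early" (cheap) chunk with a "late" (expensive) one so every rank sees a similar causal workload. The function name below is illustrative, not from this repo.

```python
def zigzag_chunks(seq_len, cp_size):
    n_chunks = 2 * cp_size
    chunk = seq_len // n_chunks
    assignment = {}
    for r in range(cp_size):
        # Pair the r-th chunk with its mirror from the end of the sequence.
        first = (r * chunk, (r + 1) * chunk)
        last = ((n_chunks - 1 - r) * chunk, (n_chunks - r) * chunk)
        assignment[r] = [first, last]
    return assignment

print(zigzag_chunks(8192, 4))
# rank 0 gets [0, 1024) and [7168, 8192); rank 3 gets [3072, 4096) and [4096, 5120)
```

With packing (varlen), this split has to be applied per sub-sequence inside each packed batch, which is what makes the combination troublesome.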

@TechxGenus
Author

Got it. Thanks for the detailed answer.

@hhaAndroid

Does implementing this feature require directly modifying the source code in Flash Attention, or can it be achieved by calling internal interfaces? Thank you

@Infi-zc

Infi-zc commented Feb 12, 2025

> Does implementing this feature require directly modifying the source code in Flash Attention, or can it be achieved by calling internal interfaces? Thank you

Hi @hhaAndroid, have you reproduced this scheduling algorithm? How well does it work?
