Investigate and implement Flash Attention #5

Closed
0x000011b opened this issue Feb 19, 2023 · 2 comments
Labels: enhancement (New feature or request)

Comments

@0x000011b

For our use case of fine-tuning LMs on up to 2048 tokens, flash attention might get us a ~2-4x speedup and a VRAM usage reduction of up to 10x. Sounds pretty amazing, so I'd like to give it a shot. Some code inspiration:
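
A rough sketch of what the drop-in could look like (this assumes the `xformers` package; the shapes, dtype, and causal mask below are placeholders rather than our actual training code):

```python
# Sketch only: replacing a standard softmax attention call with xFormers'
# memory-efficient (flash) attention. Shapes and dtype are placeholder assumptions.
import torch
import xformers.ops as xops

batch, seq_len, n_heads, head_dim = 2, 2048, 16, 64

# xformers.ops.memory_efficient_attention expects [batch, seq, heads, head_dim]
q = torch.randn(batch, seq_len, n_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Causal (decoder-style) masking via the built-in lower-triangular bias
out = xops.memory_efficient_attention(
    q, k, v,
    attn_bias=xops.LowerTriangularMask(),
    p=0.0,  # attention dropout probability
)
# `out` keeps the [batch, seq, heads, head_dim] layout and would be reshaped
# back to [batch, seq, hidden] before the output projection.
```

The flash kernel path generally wants fp16/bf16 tensors on GPU, hence the dtype above.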

0x000011b converted this from a draft issue Feb 19, 2023
0x000011b added the enhancement label Feb 19, 2023
0x000011b moved this from 📋 Backlog to 🏗 In progress in AI/ML Model Backlog Mar 3, 2023
@0x000011b
Author

@TearGosling and I played around with optimizing NeoX by using components from xFormers after I profiled the training code. Results:

  • Usage of the MLP component there causes a performance drop and lots of warnings, so we're ignoring this one
  • Usage of flash attention results in:
    • 👍 A decent speedup (~17% IIRC)
    • 👎 A significant increase in training and evaluation loss
    • 😐 No noticeable VRAM savings
  • Usage of the rotary embedding implementation (a sketch of what RoPE computes is below) results in:
    • 👍 A decent speedup (when used together with flash attention, gets us over a 20% throughput increase IIRC)
    • 😐 A very minor increase in training loss
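
For anyone following along, this is roughly what the rotary embedding computes (standard RoPE math written out by hand, not xFormers' actual implementation; shapes are placeholders):

```python
# Sketch of rotary position embeddings (RoPE) applied to q/k before attention.
# This is the standard formulation, not xFormers' internal code.
import torch

def build_rope_cache(seq_len: int, head_dim: int, device, dtype, base: float = 10000.0):
    # Frequencies for each pair of head dimensions
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    positions = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(positions, inv_freq)   # [seq, head_dim / 2]
    emb = torch.cat((freqs, freqs), dim=-1)    # [seq, head_dim]
    return emb.cos().to(dtype), emb.sin().to(dtype)

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # q, k: [batch, seq, heads, head_dim]; cos/sin broadcast over batch and heads
    cos = cos[None, :, None, :]
    sin = sin[None, :, None, :]
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

# Rotate q/k, then hand them to the flash attention call from the sketch above
q = torch.randn(2, 2048, 16, 64)
k = torch.randn_like(q)
cos, sin = build_rope_cache(seq_len=2048, head_dim=64, device=q.device, dtype=q.dtype)
q, k = apply_rope(q, k, cos, sin)
```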

Next steps:

  • Figure out whether the higher loss values with flash attention actually translate to a worse model, or whether they can be accounted for by some other factor we're unaware of (a numerical parity check like the sketch after this list would be a start).
  • The lack of noticeable VRAM savings with flash attention doesn't make a lot of sense. Investigate whether it's due to e.g. memory fragmentation (since we're not properly pre-allocating memory for dynamic tensor sizes); the same sketch prints allocated vs. reserved memory for this.
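
Something like this is what I have in mind for both checks (a sketch with placeholder shapes, not our training loop):

```python
# Sketch: (1) compare flash attention output against a plain softmax-attention
# reference to see whether the op itself diverges numerically, and
# (2) report allocated vs. reserved CUDA memory, where a large gap hints at
# fragmentation rather than genuine usage.
import math
import torch
import xformers.ops as xops

def reference_attention(q, k, v):
    # q, k, v: [batch, seq, heads, head_dim] -> work in [batch, heads, seq, head_dim]
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    causal = torch.triu(torch.ones(scores.shape[-2:], device=q.device, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal, float("-inf"))
    return (scores.softmax(dim=-1) @ v).transpose(1, 2)

q = torch.randn(2, 2048, 16, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

flash = xops.memory_efficient_attention(q, k, v, attn_bias=xops.LowerTriangularMask())
ref = reference_attention(q, k, v)
print("max abs diff:", (flash - ref).abs().max().item())  # should be small (fp16 noise)

# Allocated vs. reserved: a big gap after some training steps suggests the
# caching allocator is fragmenting on varying tensor sizes.
print("allocated MB:", torch.cuda.memory_allocated() / 2**20)
print("reserved  MB:", torch.cuda.memory_reserved() / 2**20)
```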

@0x000011b
Author

Flash attention via xFormers landed in #8. The "lack" of VRAM savings was actually just memory fragmentation messing us up.
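
For anyone hitting the same thing: one common way to avoid that kind of fragmentation is to keep tensor shapes static by always padding to the full context length. A sketch assuming a Hugging Face tokenizer (illustrative only, not necessarily how #8 handles it):

```python
# Sketch: padding every example to a fixed length so the CUDA caching allocator
# sees identical tensor shapes each step (a common way to avoid fragmentation;
# not necessarily what the fix in #8 does). Model name is illustrative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer.pad_token = tokenizer.eos_token  # NeoX tokenizer has no pad token by default

texts = ["example document one", "a much longer example document two ..."]
batch = tokenizer(
    texts,
    padding="max_length",   # always pad to max_length, not just to the longest in the batch
    max_length=2048,
    truncation=True,
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # [2, 2048] every step, regardless of input lengths
```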

github-project-automation bot moved this from 🏗 In progress to ✅ Done in AI/ML Model Backlog May 2, 2023