For our use case of fine-tuning LMs on sequences of up to 2048 tokens, flash attention could get us a ~2-4x speedup and a VRAM usage reduction of up to 10x. That sounds pretty amazing, so I'd like to give it a shot. Some code inspiration:
- diffusers PR 532: shows how to use flash attention through xformers, probably the least painful way to go about it (rough sketch below)
- GPT-NeoX PR 725: flash attention implementation in Eleuther's fork of Megatron
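For reference, here's a minimal sketch of what the xformers route looks like, along the lines of the diffusers PR. The `attention_forward` wrapper, tensor shapes, and the point where this would plug into NeoX are my own placeholders; the real API here is `xformers.ops.memory_efficient_attention` (the causal-mask helper name may differ across xformers versions).

```python
# Minimal sketch (not our training code) of routing attention through
# xFormers' memory-efficient / flash attention kernel.
import torch
import xformers.ops as xops


def attention_forward(q, k, v, dropout_p=0.0):
    # q, k, v: [batch, seq_len, num_heads, head_dim] -- the layout
    # memory_efficient_attention expects. The kernel dispatches to a
    # flash-attention-style implementation when hardware/dtype allow it
    # and never materializes the full [seq_len, seq_len] attention matrix.
    return xops.memory_efficient_attention(
        q, k, v,
        attn_bias=xops.LowerTriangularMask(),  # causal masking for LM training
        p=dropout_p,
    )


if __name__ == "__main__":
    b, s, h, d = 2, 2048, 16, 64
    q = torch.randn(b, s, h, d, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = attention_forward(q, k, v)  # [b, s, h, d]
    print(out.shape)
```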
@TearGosling and I experimented with optimizing NeoX using components from xFormers after I profiled the training code. Results:

- Using the xFormers MLP component causes a performance drop and lots of warnings, so we're ignoring it.
- Using flash attention results in:
  - 👍 A decent speedup (~17% IIRC)
  - 👎 A significant increase in training and evaluation loss
  - 😐 No noticeable VRAM savings
- Using the xFormers rotary embedding implementation (rough wiring sketch below) results in:
  - 👍 A decent speedup (when used together with flash attention, it gets us over a 20% throughput increase IIRC)
  - 😐 A very minor increase in training loss
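For the rotary embedding swap, this is roughly what the wiring looks like. The import path and call signature are from memory, so double-check them against the xformers version we pin; shapes here are illustrative, not lifted from NeoX.

```python
# Rough sketch of swapping NeoX's rotary embedding for the xFormers one.
import torch
from xformers.components.positional_embedding import RotaryEmbedding

head_dim = 64
rotary = RotaryEmbedding(dim_model=head_dim)

# q, k: [batch, num_heads, seq_len, head_dim]
q = torch.randn(2, 16, 2048, head_dim)
k = torch.randn_like(q)

# Replaces NeoX's own apply_rotary_pos_emb; output shapes are unchanged.
q_rot, k_rot = rotary(q, k)
print(q_rot.shape, k_rot.shape)
```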
Next steps:
- Figure out whether the higher loss values with flash attention actually translate into a worse model, or whether they can be explained by some other factor we're not aware of.
- The lack of noticeable VRAM savings with flash attention doesn't make much sense. Investigate whether it's due to e.g. memory fragmentation (since we're not properly pre-allocating memory for dynamic tensor sizes); see the sketch below for a quick way to check.
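For the VRAM question, a cheap first check is to compare how much memory live tensors actually occupy against how much the caching allocator has reserved from the driver; a large gap, especially at the peaks, points at fragmentation/caching rather than the attention kernel itself. Something like the helper below, logged every N steps (the `log_vram` name and call sites are just illustrative):

```python
import torch


def log_vram(tag: str) -> None:
    # Memory occupied by live tensors vs. memory the caching allocator has
    # reserved from CUDA. A big reserved-vs-allocated gap suggests
    # fragmentation from dynamic tensor sizes rather than genuine usage.
    allocated = torch.cuda.memory_allocated() / 2**30
    peak_alloc = torch.cuda.max_memory_allocated() / 2**30
    reserved = torch.cuda.memory_reserved() / 2**30
    peak_res = torch.cuda.max_memory_reserved() / 2**30
    print(f"[{tag}] allocated={allocated:.2f} GiB (peak {peak_alloc:.2f}), "
          f"reserved={reserved:.2f} GiB (peak {peak_res:.2f})")


# Usage: call torch.cuda.reset_peak_memory_stats() at the start of each
# measurement window, then log_vram("after step") once every N steps,
# with and without flash attention enabled.
```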