Llama flash attn #86
Conversation
Fails with the same error on a 48GB card without 8-bit and without LoRA.
…tention is done, so early on. NOTE: flash attention requires installing CUDA 11.7 via https://developer.nvidia.com/cuda-11-7-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=runfile_local. When running the installer, skip the driver, docs, and samples and install only the toolkit. Then, when pip installing flash attention, do: CUDA_HOME=/usr/local/cuda-11.7 pip install flash-attn
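For context, here is a minimal sketch of how such a flash-attention monkey patch is typically applied before the model is built; it is modeled on FastChat's llama_flash_attn_monkey_patch (referenced further down), not the exact code in this PR, and the checkpoint name is a placeholder:

```python
# Sketch only: patch transformers' LLaMA attention to use flash-attn, assuming
# FastChat's helper is installed (pip install fschat) and flash-attn was built
# against CUDA 11.7 as described above.
from fastchat.train.llama_flash_attn_monkey_patch import (
    replace_llama_attn_with_flash_attn,
)

replace_llama_attn_with_flash_attn()  # must run before the model is instantiated

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",   # placeholder LLaMA checkpoint
    torch_dtype=torch.float16,
    device_map="auto",
)
```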
On a 24GB board, one can run 7B in 16-bit with a 512 cutoff, and it uses 92% of memory.
See also:
BetterTransformer now has built-in support for flash attention via torch updates, but only for certain parts of the operations. We can try: https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2 To run:
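For reference, a hedged sketch of the BetterTransformer wrapping (not the exact command used in this thread), assuming Hugging Face optimum is installed and using a placeholder model name:

```python
# Sketch: wrap an HF model so supported modules dispatch to torch's native
# scaled-dot-product / flash attention kernels. Requires `pip install optimum`
# and a recent torch.
import torch
from transformers import AutoModelForCausalLM
from optimum.bettertransformer import BetterTransformer

model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.float16)
model = BetterTransformer.transform(model)
# ...generate/train as usual; BetterTransformer.reverse(model) undoes the wrap.
```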
Training works fine:
Will be done with 6 epochs in 5 hours. Training the 20B was 12 s/iteration; this 30B llama runs at 2 iterations/s, i.e. 0.5 s/iteration (note the inversion of units), so 24x faster for a 50% larger model.
No crash during checkpointing at save_steps=2000.
No crash with LoRA in 16-bit either, with 0.01 epochs. So all good.
Non-test run on 31eef24. The below is the content of llama30b.sh, run as:
To ensure a checkpoint roughly every epoch, with:
Roughly 8 hours:
This OOMs:
So far not OOMing with just a 512 cutoff. ETA 8 hours.
Much lower loss:
_7 case with eval:
8-bit will take 30 hours. Won't run this one.
The model alone is 42% of 80GB:
8-bit with a larger batch size leads to more efficient GPU use:
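As a rough illustration of the 8-bit + LoRA + larger-batch setup (a sketch with assumed model name and hyperparameters, not the repo's finetune invocation; prepare_model_for_int8_training comes from older peft releases):

```python
# Sketch: load the base model in 8-bit, attach LoRA adapters, and use a larger
# per-device batch size; all names and values here are illustrative.
import torch
from transformers import AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-30b",          # placeholder checkpoint
    load_in_8bit=True,
    device_map="auto",
)
model = prepare_model_for_int8_training(model)
model = get_peft_model(model, LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM"))

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=16,   # larger batch to keep the GPU busy
    gradient_accumulation_steps=1,
    fp16=True,
    save_steps=2000,
)
```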
Refactor finetune so some of it can be used to check data and its tokenization
…runcation rows to avoid learning from truncated language
OIG w/ OASST mixed in with pure OASST at an equal level. ETA is about 22 hours for 2 epochs, which is roughly equivalent to 20 epochs over pure OASST for the OASST part of the data. So checkpoints are taken every 0.25 epochs in order to do a later eval and choose the best model.
Early logs:
At about 1 epoch (yellow line):
At 1.6 epochs:
At exactly 2 epochs, when it would have finished, it hit an odd error:
llama 7B: batch sizes 16, 32, and 64 all OOM; 16 almost survived but died soon after, while 32 and 64 died early.
epoch ~8:
The purple one, about loss=0.5 by the end.
For llama 30B _7, default generate eval: score_llama30B_jon7h.log. Left is GPT-3.5 Turbo, right is 30B llama _7 (more creative):
Still much worse than GPT-3.5 Turbo:
llama is sensitive to non-helpful training data in the OASST raw data:
Note:
Requires an A100 or newer to work out of the box; otherwise, avoiding errors requires patching transformers or using nightly torch.
Would need to patch transformers: lm-sys/FastChat#581 (comment)
We could also try the BetterTransformer wrapper, which wraps HF models into BT models using torch's native flash attention (reportedly slower but more memory efficient):
https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#:~:text=Scaled%20dot%20product%20attention%20attempts,for%20enabling%20and%20disabling%20implementations.
That might require pytorch nightly to fix a related bug where fast attention should have been disabled for some head sizes when sm80 is not available. But we could then at least test on an A6000/4090, and it should all work on an A100, using flash attention to some degree without the head-size limit.
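For reference, a small sketch of torch's native scaled_dot_product_attention (torch >= 2.0) with kernel selection pinned to flash, which is where the head-size / sm80 limitation shows up; shapes are illustrative:

```python
# Sketch: call torch's fused SDPA directly and force the flash kernel so that
# unsupported head sizes / GPUs fail loudly instead of silently falling back.
import torch
import torch.nn.functional as F

# (batch, heads, seq, head_dim); LLaMA uses head_dim=128, the size that needs
# sm80 (A100-class) for the flash kernel in older torch builds.
q = torch.randn(1, 32, 512, 128, device="cuda", dtype=torch.float16)
k = torch.randn(1, 32, 512, 128, device="cuda", dtype=torch.float16)
v = torch.randn(1, 32, 512, 128, device="cuda", dtype=torch.float16)

with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```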
See also:
pytorch/pytorch#99105
pytorch/pytorch#98771
pytorch/pytorch#98140
huggingface/transformers#18439
lm-sys/FastChat#459
pytorch/pytorch#94883