
Merge pull request #31 from h2oai/flash-attn
neox Flash attn
arnocandel authored May 11, 2023
2 parents a491987 + cc43cc5 commit 1e1540e
Showing 2 changed files with 3 additions and 5 deletions.
2 changes: 0 additions & 2 deletions finetune.py
@@ -226,8 +226,6 @@ def train(
 NOTE: for current pytorch 2.0, flash attention requires installing cuda 11.7 via https://developer.nvidia.com/cuda-11-7-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=runfile_local and then when running, to avoid installing driver, docs, samples, just install toolkit. Then when pip installing flash attention do:
 CUDA_HOME=/usr/local/cuda-11.7 pip install flash-attn""")
-        from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
-        replace_llama_attn_with_flash_attn()
     assert (
         base_model
     ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
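The two lines deleted from finetune.py are the LLaMA-specific flash-attention monkey patch. For context, the sketch below shows the general pattern that code followed: the patch has to run before the model is constructed so that the replacement attention forward is the one actually used. Only the import and the call come from the diff above; the wrapper function and its flag are illustrative assumptions, not code from the repository.

```python
# Illustrative sketch only: the import and the call are taken from the removed
# lines above; the surrounding helper and flag are assumed for clarity.
def maybe_enable_llama_flash_attn(use_flash_attn: bool) -> None:
    if use_flash_attn:
        # Apply the patch before the model is instantiated so the patched
        # attention forward is picked up when the model classes are built.
        from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
        replace_llama_attn_with_flash_attn()
```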
6 changes: 3 additions & 3 deletions requirements.txt
@@ -1,13 +1,13 @@
 # for generate (gradio server) and finetune
-datasets==2.11.0
+datasets==2.12.0
 sentencepiece==0.1.97
 accelerate==0.18.0
 gradio==3.27.0
-huggingface_hub==0.13.4
+huggingface_hub==0.14.1
 appdirs==1.4.4
 fire==0.5.0
 docutils==0.19
-torch==2.0.0
+torch==2.0.1
 evaluate==0.4.0
 rouge_score==0.1.2
 sacrebleu==2.3.1
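The requirements change bumps datasets, huggingface_hub, and torch to newer releases. A small, hypothetical check (not part of the repository) to confirm an environment matches the new pins:

```python
# Hypothetical helper, not from the repo: compare the pins bumped in this
# commit against what is actually installed in the current environment.
from importlib.metadata import PackageNotFoundError, version

EXPECTED = {"datasets": "2.12.0", "huggingface_hub": "0.14.1", "torch": "2.0.1"}

for package, wanted in EXPECTED.items():
    try:
        installed = version(package)
    except PackageNotFoundError:
        installed = "not installed"
    status = "OK" if installed == wanted else "MISMATCH"
    print(f"{package}: expected {wanted}, installed {installed} -> {status}")
```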
