#!/bin/bash
# run_assist.sh: launch supervised fine-tuning (SFT) of a 7B base model on the
# assist dataset with LLaMA-Factory and DeepSpeed.
TIME=$(date "+%m-%d-%H-%M")
TASK=steps
DATASET=assist_$TASK
TEMPLATE=$TASK
# wandb
export WANDB_PROJECT=grace-assist-$TASK
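# Logging to W&B (via --report_to wandb below) assumes you are already
# authenticated, e.g. through `wandb login` or a WANDB_API_KEY environment
# variable set outside this script.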
# Select the base model: "mistral" (default) or "llama".
BASE_MODEL=mistral
# BASE_MODEL=llama
# The snapshot paths below point at a local Hugging Face cache; a hub ID such as
# mistralai/Mistral-7B-v0.1 should also work if the weights can be downloaded.
if [ "$BASE_MODEL" = "mistral" ]; then
    OUTPUT_DIR=~/models/assist/mistral-7b-$TEMPLATE-$DATASET-$TIME
    MODEL_NAME_OR_PATH="/data/cache/huggingface/hub/models--mistralai--Mistral-7B-v0.1/snapshots/5e9c98b96d071dce59368012254c55b0ec6f8658"
elif [ "$BASE_MODEL" = "llama" ]; then
    OUTPUT_DIR=~/models/assist/llama-7b-$TEMPLATE-$DATASET-$TIME
    MODEL_NAME_OR_PATH=/data/cache/huggingface/hub/models--meta-llama--Llama-2-7b-hf/snapshots/6fdf2e60f86ff2481f2241aaee459f85b5b0bbb9
else
    echo "Unknown BASE_MODEL: $BASE_MODEL" >&2
    exit 1
fi
VAL_SIZE=0.01
NUM_GPUS=8
# LR is still being tuned; 5e-5 was the earlier candidate, 5e-7 is the current value.
# LR=5e-5
LR=5e-7
EPOCHS=3
CUTOFF_LEN=4096
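# With NUM_GPUS=8, per_device_train_batch_size=4 and gradient_accumulation_steps=4
# (both set in the launch command below), the effective global batch size is
# 8 * 4 * 4 = 128 sequences per optimizer step, assuming a single node.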
# accelerate launch src/train_bash.py \
# deepspeed --hostfile hostfile.txt src/train_bash.py \
deepspeed --num_gpus $NUM_GPUS --master_port=9901 src/train_bash.py \
--deepspeed "/root/LLaMA_Factory/LLaMA-Factory/ds_config.json" \
--stage sft \
--model_name_or_path "$MODEL_NAME_OR_PATH" \
--do_train True \
--overwrite_cache False \
--finetuning_type full \
--template $TEMPLATE \
--dataset_dir data \
--dataset $DATASET \
--cutoff_len $CUTOFF_LEN \
--learning_rate $LR \
--num_train_epochs $EPOCHS \
--max_samples 100000 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 4 \
--lr_scheduler_type cosine \
--max_grad_norm 1.0 \
--logging_steps 5 \
--save_steps 100 \
--warmup_steps 0 \
--lora_rank 8 \
--lora_dropout 0.1 \
--lora_target q_proj,v_proj \
--resume_lora_training True \
--output_dir "$OUTPUT_DIR" \
--fp16 \
--plot_loss True \
--val_size $VAL_SIZE \
--evaluation_strategy steps \
--eval_steps 10 \
--report_to wandb \
--flash_attn
# --load_best_model_at_end True
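# For reference, a minimal ZeRO-2 DeepSpeed config for the ds_config.json path used
# above might look like the sketch below. This is an illustrative assumption, not the
# config actually shipped in this repo; "auto" values are filled in from the HF
# training arguments at launch time.
# {
#   "train_micro_batch_size_per_gpu": "auto",
#   "gradient_accumulation_steps": "auto",
#   "gradient_clipping": "auto",
#   "fp16": { "enabled": "auto" },
#   "zero_optimization": { "stage": 2 }
# }
# Note: with --finetuning_type full, the --lora_* flags above should have no effect;
# they only apply when --finetuning_type lora is selected.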