-
Notifications
You must be signed in to change notification settings - Fork 10.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
llama/ggml: add LLM training support #10544
base: master
Are you sure you want to change the base?
Conversation
I pushed a version that I think is in a state where it could be merged.
My immediate next goals will be:
On a somewhat related note, it may make sense to refactor the file |
@JohannesGaessler you may see #10902 |
The link doesn't work. |
@JohannesGaessler sorry, #10902 |
I've started working on this again, I rebased my local branch onto master and am currently adding the missing ops for CUDA training. This PR is getting quite large; in terms of reviewing, would you prefer if I split off things that can be reviewed and merged on their own? |
If you can separate things in standalone PRs, it's always helpful (maybe the CUDA ops can be in a standalone PR). |
See ggerganov/ggml#1025 except I decided to implement the training directly in llama.cpp after all because the GPT-2 GGML example is already pretty complex, would require a significant amount of effort to refactor, and I'm not familiar with the codebase at all.
The goal of this PR is to add general training support to llama.cpp using
ggml_opt
. CPU training seems to work, other backends are missing support for some GGML ops. It's currently not possible to actually save the finetuned model to disk but you can confirm that the finetuning works by doing one epoch over the input text prior to perplexity calculation (or by observing how the loss goes down with the new finetune example). One epoch over the test set of Wikitext-2 (with the stride chosen in such a way that each token is used twice per epoch) currently takes ~1 minute with Stories 260k or ~20 hours and ~100 GB RAM with LLaMA 3 8b. For the user-facing API my concrete plans are:n_ctx
determines the max. sequence length with which the model is trained.n_batch
determines how many tokens are consumed per optimizer step.n_ubatch
determines the number of tokens in parallel, enables speed <-> memory use tradeoff, should have no effect on the result except for differences in floating point rounding error.std::vector<llama_token>
. Currently I have this as part ofllama.h
but maybe this would make more sense to put incommon.h
?llama_opt_init
that prepares allama_context
for training and lets the user define things like the learning rate or which tensors should be trainable parameters.llama_opt_epoch
that performs one epoch over aggml_opt_dataset
, equivalent toggml_opt_epoch
.llama_opt_fit
equivalent toggml_opt_fit
that is even more high-level?Currently, while functional, the PR is in a bad state in terms of software design and is in need of a refactor. The reason I'm already opening it now is because I want to ask for advice regarding how to best implement
llama_opt_epoch
. My current approach was to try and hijack the first half ofllama_decode_internal
but I found that in the end all I needed from it was the generation of the nextllama_ubatch
and the corresponding manipulation of the KV cache. But maybe it would make more sense to instead write a function likellama_prepare_next_ubatch
and to use that function inllama_decode_internal
andllama_opt_epoch
?