Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama/ggml: add LLM training support #10544

Open
wants to merge 12 commits into
base: master
Choose a base branch
from

Conversation

JohannesGaessler
Copy link
Collaborator

@JohannesGaessler JohannesGaessler commented Nov 27, 2024

See ggerganov/ggml#1025 except I decided to implement the training directly in llama.cpp after all because the GPT-2 GGML example is already pretty complex, would require a significant amount of effort to refactor, and I'm not familiar with the codebase at all.

The goal of this PR is to add general training support to llama.cpp using ggml_opt. CPU training seems to work, other backends are missing support for some GGML ops. It's currently not possible to actually save the finetuned model to disk but you can confirm that the finetuning works by doing one epoch over the input text prior to perplexity calculation (or by observing how the loss goes down with the new finetune example). One epoch over the test set of Wikitext-2 (with the stride chosen in such a way that each token is used twice per epoch) currently takes ~1 minute with Stories 260k or ~20 hours and ~100 GB RAM with LLaMA 3 8b. For the user-facing API my concrete plans are:

  • The parameter n_ctx determines the max. sequence length with which the model is trained.
  • The parameter n_batch determines how many tokens are consumed per optimizer step.
  • The parameter n_ubatch determines the number of tokens in parallel, enables speed <-> memory use tradeoff, should have no effect on the result except for differences in floating point rounding error.
  • A function with which the user can initialize a dataset from type std::vector<llama_token>. Currently I have this as part of llama.h but maybe this would make more sense to put in common.h?
  • A function llama_opt_init that prepares a llama_context for training and lets the user define things like the learning rate or which tensors should be trainable parameters.
  • A function llama_opt_epoch that performs one epoch over a ggml_opt_dataset, equivalent to ggml_opt_epoch.
  • Maybe a function like llama_opt_fit equivalent to ggml_opt_fit that is even more high-level?

Currently, while functional, the PR is in a bad state in terms of software design and is in need of a refactor. The reason I'm already opening it now is because I want to ask for advice regarding how to best implement llama_opt_epoch. My current approach was to try and hijack the first half of llama_decode_internal but I found that in the end all I needed from it was the generation of the next llama_ubatch and the corresponding manipulation of the KV cache. But maybe it would make more sense to instead write a function like llama_prepare_next_ubatch and to use that function in llama_decode_internal and llama_opt_epoch?

@github-actions github-actions bot added testing Everything test related Nvidia GPU Issues specific to Nvidia GPUs examples ggml changes relating to the ggml tensor library for machine learning labels Nov 27, 2024
@JohannesGaessler JohannesGaessler added the Review Complexity : High Generally require indepth knowledge of LLMs or GPUs label Nov 27, 2024
@JohannesGaessler JohannesGaessler marked this pull request as ready for review December 1, 2024 23:15
@JohannesGaessler
Copy link
Collaborator Author

I pushed a version that I think is in a state where it could be merged.

  • I refactored llama_decode_internal and split off functions llama_prepare_sbatch and llama_prepare_ubatch that can be called from llama_opt_epoch.
  • ggml training now has calls ggml_opt_alloc and ggml_opt_eval instead of ggml_opt_forward and ggml_opt_forward_backward. When not using static graphs a call to ggml_opt_prepare_alloc is also needed to provide a new forward graph.
  • I added a function llama_save_model_to_file for converting a llama_model to a GGUF file. For finetuning it would have been possible to copy a lot of the data from the input file but for training a model from scratch a method like this will be needed anyways. Currently tensors with non-CPU data cause a segfault when passed to the GGUF code, see GGUF: ggml backend support for writing tensor data ggml#1033 .
  • To control which tensors should be trainable parameters the user can pass a function that filters the tensors in a model.

My immediate next goals will be:

  • Fixing GGUF for non-CPU tensors.
  • CUDA support for the operations missing for training.
  • Support for FP16/BF16.

On a somewhat related note, it may make sense to refactor the file llama.cpp in such a way that moves code to other files; for some cases my IDE is starting to get a little sluggish when working on a 22k LOC file.

@lexasub
Copy link

lexasub commented Dec 29, 2024

@JohannesGaessler you may see #10902

@JohannesGaessler
Copy link
Collaborator Author

The link doesn't work.

@lexasub
Copy link

lexasub commented Dec 30, 2024

@JohannesGaessler sorry, #10902

@JohannesGaessler
Copy link
Collaborator Author

I've started working on this again, I rebased my local branch onto master and am currently adding the missing ops for CUDA training. This PR is getting quite large; in terms of reviewing, would you prefer if I split off things that can be reviewed and merged on their own?

@ggerganov
Copy link
Owner

If you can separate things in standalone PRs, it's always helpful (maybe the CUDA ops can be in a standalone PR).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
examples ggml changes relating to the ggml tensor library for machine learning Nvidia GPU Issues specific to Nvidia GPUs Review Complexity : High Generally require indepth knowledge of LLMs or GPUs testing Everything test related
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants