
llama : refactor session file management #8699

Merged: 7 commits merged into master from compilade/refactor-session-files on Jul 28, 2024

Conversation

compilade (Collaborator) commented Jul 26, 2024

Follow-up from #8526 (comment)

When updating the KV cache structure, keeping the session files working requires touching at least 6 places (2 types of session files, each with read, write, and size-calculation code), which is hard to do consistently.

To make this easier to maintain, I've unified the format of the seq_id-specific session files and whole KV cache session files. They still use separate MAGIC and VERSION, but at least now they share most of the save-and-restore code.

After this, there will be only 2 places to update when making changes: the writing and the reading. There is no need to try to get the size calculation right (see the changes to llama_state_get_size in the summary below), and no need to maintain 2 separate ways of both reading and writing the KV cache content.

Summary

  • Breaking changes
    • Saving and restoring state checks for overflow
      • The size of the buffers should now be given to the functions working with them, otherwise a truncated file could cause out of bound reads.
    • llama_state_get_size returns the actual size instead of the max
      • The size is calculated by the same code which saves the state, run with a dummy data context that counts bytes without writing anything (a minimal sketch of this pattern follows this summary).
      • This is a breaking change, but it makes that function much easier to keep up to date, and it also makes it reflect the behavior of llama_state_seq_get_size.
  • Improvements
    • Avoid using size_t in session files
    • Stream from session file instead of copying into a big buffer
      • Loading session files should no longer cause a memory usage spike.
    • Share code between whole and seq_id-specific state saving
      • Both session file types now use a more similar format.
    • No longer store all hparams in session files
      • Instead, the model arch name is stored. The layer count and the embedding dimensions of the KV cache are still verified when loading.
      • This makes llama_hparams no longer technically required to be trivially copyable, which should allow removing the LLAMA_MAX_LAYERS limit in a future refactor.
    • seq_id-specific session files should now also work with recurrent models like Mamba
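
To illustrate the dummy data context mentioned above: the sketch below is a minimal, hypothetical rendition of the pattern (state_write_ctx, write_state, and friends are illustrative names, not the PR's actual types). One serialization routine walks the state; the destination decides whether the bytes are stored or merely counted.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical interface: the serializer calls write(), the destination
// decides what happens to the bytes.
struct state_write_ctx {
    virtual void write(const void * src, size_t size) = 0;
    virtual size_t bytes_written() const = 0;
    virtual ~state_write_ctx() = default;
};

// Real destination: append to a buffer.
struct buffer_write_ctx : state_write_ctx {
    std::vector<uint8_t> buf;
    void write(const void * src, size_t size) override {
        const uint8_t * p = static_cast<const uint8_t *>(src);
        buf.insert(buf.end(), p, p + size);
    }
    size_t bytes_written() const override { return buf.size(); }
};

// Dummy destination: discard the data, only accumulate the size.
struct dummy_write_ctx : state_write_ctx {
    size_t n = 0;
    void write(const void * /*src*/, size_t size) override { n += size; }
    size_t bytes_written() const override { return n; }
};

// A single serialization routine serves both "save" and "get size".
static void write_state(state_write_ctx & ctx, uint32_t n_tokens, const int32_t * tokens) {
    ctx.write(&n_tokens, sizeof(n_tokens));
    ctx.write(tokens, n_tokens * sizeof(int32_t));
}
```

A llama_state_get_size-style function can then run the same write_state over a dummy_write_ctx and report bytes_written(), so the reported size cannot drift from what saving actually produces.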

TODO

  • Test session files with Mamba
    • llama-save-load-state
      • Works, but seq_id-specific state loading fails because slots are not contiguous (did not work before anyway)
    • llama-cli
      • Works, but rollback doesn't.
  • Test session files with a Transformer-based model
    • llama-save-load-state
    • llama-cli
      • Works, but the saved rng is not the right one, because of llama_sampling vs llama_sampling_context. Same problem as on master.
    • llama-server
      • I think this is handled by the server test suite
  • Test session files with FlashAttention
    • llama-save-load-state
    • llama-cli
    • llama-server

* llama : saving and restoring state checks for overflow

The size of the buffers should now be given to the functions working
with them, otherwise a truncated file could cause out of bound reads.
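
The shape of that check can be sketched as follows (checked_reader is a hypothetical illustration, not the PR's actual read-context type):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>

// Hypothetical bounds-checked reader: the caller supplies the buffer size,
// and every read verifies it stays within that bound.
struct checked_reader {
    const uint8_t * data;
    size_t size;
    size_t off = 0;

    void read(void * dst, size_t n) {
        if (n > size - off) { // off <= size is an invariant, so no underflow
            throw std::runtime_error("state buffer too small (truncated file?)");
        }
        std::memcpy(dst, data + off, n);
        off += n;
    }
};
```

A truncated file then fails with an explicit error instead of reading past the end of the buffer.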

* llama : stream from session file instead of copying into a big buffer

Loading session files should no longer cause a memory usage spike.
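
Streaming amounts to backing the read side with the file handle itself instead of a preloaded buffer; a sketch under assumed names:

```cpp
#include <cstddef>
#include <cstdio>
#include <stdexcept>

// Hypothetical file-backed read context: bytes are pulled from the file
// on demand, so the whole session never has to sit in memory at once.
struct file_read_ctx {
    FILE * f;

    void read(void * dst, size_t n) {
        if (std::fread(dst, 1, n, f) != n) {
            throw std::runtime_error("unexpected end of session file");
        }
    }
};
```

Peak memory then scales with the largest single field being read rather than with the total file size.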

* llama : llama_state_get_size returns the actual size instead of max

This is a breaking change, but makes that function *much* easier
to keep up to date, and it also makes it reflect the behavior
of llama_state_seq_get_size.

* llama : share code between whole and seq_id-specific state saving

Both session file types now use a more similar format.
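
One way to picture the sharing (hypothetical names and layout, not the PR's actual code): a single routine is parameterized by seq_id, with the whole-cache case expressed as "all sequences".

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical KV cell bookkeeping: which sequence each cache slot holds.
struct kv_cell { int32_t seq_id; int32_t pos; };

// Hypothetical shared helper: seq_id == -1 selects every cell (whole-cache
// session file), any other value selects only that sequence's cells.
static size_t count_cells(const std::vector<kv_cell> & cells, int32_t seq_id) {
    size_t n = 0;
    for (const kv_cell & c : cells) {
        if (seq_id == -1 || c.seq_id == seq_id) {
            n++;
        }
    }
    return n;
}
```

With both file types funneling through helpers like this, a change to the cell layout only has to be made once.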

* llama : no longer store all hparams in session files

Instead, the model arch name is stored.
The layer count and the embedding dimensions of the KV cache
are still verified when loading.
Storing all the hparams is not necessary.
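
The remaining load-time check can be sketched like so (hypothetical function and parameter names):

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical validation: the session file stores the arch name and the
// KV cache dimensions, which must match the currently loaded model.
static void check_session_compat(const std::string & file_arch,  uint32_t file_n_layer,
                                 const std::string & model_arch, uint32_t model_n_layer) {
    if (file_arch != model_arch) {
        throw std::runtime_error("session file arch '" + file_arch +
                                 "' does not match model arch '" + model_arch + "'");
    }
    if (file_n_layer != model_n_layer) {
        throw std::runtime_error("layer count mismatch between session file and model");
    }
}
```

Because only the arch name and a few dimensions cross the file boundary, llama_hparams itself no longer needs to be trivially copyable.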
@compilade force-pushed the compilade/refactor-session-files branch from 7de7c17 to f1b0a1f on July 26, 2024 03:41
* llama : various integer type cast and format string fixes

Some platforms use "%lu" and others "%llu" for uint64_t.
Not sure how to handle that, so casting to size_t when displaying errors.
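
Concretely, the cast allows a single portable specifier; PRIu64 from <cinttypes> is the other standard option (a small self-contained example, not code from the PR):

```cpp
#include <cinttypes>
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    uint64_t n_read = 12345;
    // Casting to size_t lets one "%zu" specifier work on every platform...
    std::fprintf(stderr, "read %zu bytes\n", (size_t) n_read);
    // ...while PRIu64 avoids the cast at the cost of clunkier syntax.
    std::fprintf(stderr, "read %" PRIu64 " bytes\n", n_read);
    return 0;
}
```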
@compilade compilade force-pushed the compilade/refactor-session-files branch from f1b0a1f to cddc899 Compare July 26, 2024 03:49
@compilade added the breaking change, refactoring, and Review Complexity : Medium labels on Jul 26, 2024
ggerganov (Owner) left a comment

Nice job!

I guess you can drop the _context suffix for these new types:

  • llama_data_context -> llama_data
  • llama_data_read_context -> llama_data_read
  • etc.

The reason to have the suffix in llama_context, ggml_context, gguf_context is that they are top-level objects and without the suffix it would be too ambiguous in terms of symbol names (i.e. grep for ggml will match the entire codebase).

For example, llama_model is OK and does not need to be llama_model_context. Same for ggml_tensor, llama_kv_cache, llama_vocab, etc.

The llama_sampling_context in common is a mistake and I will try to fix this soon.

But even if you keep the names as they are, that's fine too.

* llama : fix session file loading

llama_state_get_size cannot be used to get the max size anymore.
compilade (Collaborator, Author)

> The llama_sampling_context in common is a mistake and I will try to fix this soon.

I recently bumped into the fact that there are both llama_sampling (from src/llama-sampling.h) and llama_sampling_context (from common/sampling.h) while testing llama-cli with --prompt-cache and wondering why the RNG state didn't seem to be restored. It seems like main.cpp uses llama_sampling_context, while what is saved in the session file is the rng of llama_sampling in llama_context (aka ctx->sampling.rng).

This means that the RNG save and restore for llama-cli with --prompt-cache is broken both on master and in this PR.

If llama_sampling_context is somehow available to llama_state_save_file and llama_state_load_file, I'm not sure to what extent the sampler state should be serialized (RNG, sampling params, grammar state, etc.).

I'd say this is out of the scope of this PR. Fixing the sampling state save and restore will at least be easier after this refactor if the session file format needs to be changed again.

ggerganov (Owner)

> This means that the RNG save and restore for llama-cli with --prompt-cache is broken both on master and in this PR.

Yup, I'm aware and will address this within #8643

* llama : remove LLAMA_MAX_RNG_STATE

It's no longer necessary to limit the size of the RNG state,
because the max size of session files is not estimated anymore.
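
std::mt19937 serializes through streams into a variable-length, platform-dependent string, so a fixed cap was only ever needed while the maximum file size had to be estimated up front. A length-prefixed field handles it naturally (hypothetical sketch, not the PR's code):

```cpp
#include <cstdint>
#include <random>
#include <sstream>
#include <string>

// Hypothetical: store the RNG state as a length-prefixed string instead of
// a fixed-size field, so no LLAMA_MAX_RNG_STATE-style cap is required.
static std::string rng_to_bytes(const std::mt19937 & rng) {
    std::ostringstream ss;
    ss << rng; // textual state; its length varies across implementations
    const std::string s = ss.str();
    const uint64_t len = s.size();
    std::string out(reinterpret_cast<const char *>(&len), sizeof(len));
    return out + s; // [length][state]
}
```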
@compilade added the merge ready label on Jul 27, 2024
@compilade merged commit 4c676c8 into master on Jul 28, 2024
54 checks passed
kaetemi (Collaborator) commented Jul 30, 2024

Nice! :)

arthw pushed a commit to arthw/llama.cpp that referenced this pull request Aug 2, 2024
* llama : refactor session file management

* llama : saving and restoring state checks for overflow

The size of the buffers should now be given to the functions working
with them, otherwise a truncated file could cause out of bound reads.

* llama : stream from session file instead of copying into a big buffer

Loading session files should no longer cause a memory usage spike.

* llama : llama_state_get_size returns the actual size instead of max

This is a breaking change, but makes that function *much* easier
to keep up to date, and it also makes it reflect the behavior
of llama_state_seq_get_size.

* llama : share code between whole and seq_id-specific state saving

Both session file types now use a more similar format.

* llama : no longer store all hparams in session files

Instead, the model arch name is stored.
The layer count and the embedding dimensions of the KV cache
are still verified when loading.
Storing all the hparams is not necessary.

* llama : fix uint64_t format type

* llama : various integer type cast and format string fixes

Some platforms use "%lu" and others "%llu" for uint64_t.
Not sure how to handle that, so casting to size_t when displaying errors.

* llama : remove _context suffix for llama_data_context

* llama : fix session file loading

llama_state_get_size cannot be used to get the max size anymore.

* llama : more graceful error handling of invalid session files

* llama : remove LLAMA_MAX_RNG_STATE

It's no longer necessary to limit the size of the RNG state,
because the max size of session files is not estimated anymore.

* llama : cast seq_id in comparison with unsigned n_seq_max