
llama : remove all_pos_0, all_pos_1, all_seq_id from llama_batch #9745

Merged
merged 15 commits into ggerganov:master from xsn/llama_batch_remove_compat
Oct 18, 2024

Conversation

ngxson
Collaborator

@ngxson ngxson commented Oct 4, 2024

Motivation

While working on the ability to add both embeddings and tokens to the same batch, I noticed that the old llama_batch API, namely all_pos_0, all_pos_1 and all_seq_id, has been around for quite a long time.

Migration guide

The recommended way is to use llama_batch_init and llama_batch_free:

llama_batch batch = llama_batch_init(n_tokens, 0, 1); // allocate a batch for n_tokens tokens, no embeddings, 1 sequence ID per token
batch.n_tokens = n_tokens;
for (int i = 0; i < n_tokens; i++) {
    batch.token   [i]    = tokens[i];  // copy token into batch
    batch.pos     [i]    = n_past + i; // set correct position for each token
    batch.n_seq_id[i]    = 1;          // each token belongs to exactly one sequence
    batch.seq_id  [i][0] = 0;          // all tokens are in sequence 0
    batch.logits  [i]    = false;      // do not output logits by default
}
batch.logits[n_tokens - 1] = true; // only get logits for the last token

if (llama_decode(ctx, batch)) {
    LOG_ERR("%s : failed to eval\n", __func__);
    llama_batch_free(batch); // remember to free the batch before returning
    return false;
}

llama_batch_free(batch);

If the binary is linked against common, you can use some helper functions (see the sketch after this list):

  • common_batch_add to add a new token into the batch
  • common_batch_clear to remove all tokens from the batch
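
For illustration, here is a rough sketch of the same single-sequence loop written with these helpers (assuming the common_batch_clear / common_batch_add signatures from common/common.h; check them against your checkout):

llama_batch batch = llama_batch_init(n_tokens, 0, 1);

common_batch_clear(batch);
for (int i = 0; i < n_tokens; i++) {
    const bool output_logits = (i == n_tokens - 1);                       // only request logits for the last token
    common_batch_add(batch, tokens[i], n_past + i, { 0 }, output_logits); // all tokens go to sequence 0
}

if (llama_decode(ctx, batch)) {
    LOG_ERR("%s : failed to eval\n", __func__);
}

llama_batch_free(batch);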

If your use case uses a single sequence, you can adapt to the new call signature of llama_batch_get_one (although this is not recommended):

if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
    LOG_ERR("%s : failed to eval\n", __func__);
    return false;
}

The position of tokens will be tracked automatically by llama_decode. For example, if you first call llama_decode on a batch of 10 tokens (positions 0 to 9), the next call to llama_decode will start decoding from position 10.
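
A minimal sketch of two consecutive calls (sampling and error handling omitted; tokenize_prompt and sample_next_token are hypothetical helpers):

std::vector<llama_token> prompt_tokens = tokenize_prompt(); // hypothetical: returns 10 tokens

// first call: the 10 tokens are decoded at positions 0..9
llama_decode(ctx, llama_batch_get_one(prompt_tokens.data(), (int32_t) prompt_tokens.size()));

// second call: decoding continues automatically from position 10
llama_token next_token = sample_next_token(ctx); // hypothetical sampling step
llama_decode(ctx, llama_batch_get_one(&next_token, 1));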


@ngxson ngxson added the breaking change Changes that break ABIs, APIs, file formats, or other forms of backwards compatibility. label Oct 4, 2024
@ngxson ngxson requested a review from ggerganov October 4, 2024 21:23
@slaren
Collaborator

slaren commented Oct 4, 2024

I don't see a clear motivation for removing this. I believe that single sequence usage is by far the most common way llama.cpp is used, and removing this function will require most applications to add a lot of boilerplate. We should aim to make the llama.cpp API as simple as possible to use.

@github-actions github-actions bot added the android Issues specific to Android label Oct 4, 2024
@ngxson
Collaborator Author

ngxson commented Oct 4, 2024

My main motivation for this PR is that instead of having an API call solely for keeping backward compatibility, we could keep it as a utility, not a core API.

The second motivation is that llama_decode accepts a batch of multiple sequences. It doesn't really care whether the batch contains one or many sequences. Therefore, llama_batch should reflect that.

Keeping these backward-compat struct members forces the code inside llama_sbatch.add_seq_to_ubatch to have 2 different branches that do (almost) the same thing. Instead of maintaining these 2 if-else branches, I believe the better solution would be to generate all of pos, n_seq_id, logits from all_pos_0, all_pos_1, all_seq_id. In other words, one can be computed from the other:

(all_pos_0, all_pos_1, all_seq_id) --> (pos, n_seq_id, logits)
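
Concretely, the expansion is roughly the following (a sketch of the equivalence, not code from llama.cpp):

// given n_tokens and the legacy scalars all_pos_0, all_pos_1, all_seq_id
std::vector<llama_pos> pos     (n_tokens);
std::vector<int32_t>   n_seq_id(n_tokens);
llama_seq_id           seq_id_0 = all_seq_id;

for (int32_t i = 0; i < n_tokens; i++) {
    pos[i]      = all_pos_0 + i * all_pos_1; // position = start + i * stride
    n_seq_id[i] = 1;                         // each token belongs to exactly one sequence
    // seq_id[i] points at &seq_id_0, i.e. every token is in sequence all_seq_id
}
// logits: with the legacy fields, only the logits of the last token are returned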

@ngxson
Collaborator Author

ngxson commented Oct 4, 2024

I believe that single sequence usage is by far the most common way llama.cpp is used, and removing this function will require most applications to add a lot of boilerplate.

I think in this use case, simply specifying n_tokens, token and pos is enough, as most users never use more than one sequence. Looked at this way, all_pos_0, all_pos_1, all_seq_id are still redundant.

So if we really want to simplify usage for the end user, we could allow the user to set only n_tokens, token and pos in llama_batch; the sequence ID can be forced to 0 in this case (as in all examples, llama_batch_get_one is always used with seq_id = 0).

Even simpler, pos can be tracked internally by looking at the KV cache, so the user can simply pass in a list of tokens (see the sketch below).
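
A sketch of what that internal tracking could look like (my illustration, not code from this PR; it assumes llama_kv_cache_seq_pos_max returns -1 for an empty sequence, otherwise an explicit emptiness check is needed):

// derive the starting position of the new batch from the KV cache
const llama_pos pos0 = llama_kv_cache_seq_pos_max(ctx, seq_id) + 1; // 0 for an empty sequence
for (int32_t i = 0; i < n_tokens; i++) {
    pos[i] = pos0 + i;
}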

@slaren
Collaborator

slaren commented Oct 4, 2024

There is a lot we could do to simplify the llama_batch API; I think the positions could be removed entirely from the public API, but that should wait for a redesign of the API.

all_pos_0, all_pos_1, all_seq_id are not truly redundant when the alternative requires allocating an array of positions, which takes a lot more code than simply calling llama_batch_get_one. If you want to simplify the code that deals with llama_batch internally, a simple solution would be to add a function, called at the start of llama_decode, that transforms these fields into a list of positions and sets the pos field of llama_batch. The same could be done to remove all_logits. Then you could remove all the code in llama.cpp that has to deal with all_pos_0, all_pos_1, all_seq_id, single-sequence users could continue using llama_batch_get_one without using the common library (which we should never encourage), and it would avoid a breaking API change.
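
Such a compatibility shim might have looked roughly like this (a hypothetical sketch of the alternative proposed here, not what was eventually merged):

static void llama_batch_expand_compat(llama_batch & batch,
                                      std::vector<llama_pos>      & pos_buf,
                                      std::vector<int32_t>        & n_seq_id_buf,
                                      std::vector<llama_seq_id *> & seq_id_buf,
                                      llama_seq_id                & seq_id_0) {
    // called at the start of llama_decode: expand the legacy scalar fields into
    // per-token arrays so the rest of the implementation only sees the array form
    if (!batch.pos) {
        pos_buf.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            pos_buf[i] = batch.all_pos_0 + i * batch.all_pos_1;
        }
        batch.pos = pos_buf.data();
    }
    if (!batch.seq_id) {
        seq_id_0 = batch.all_seq_id;
        n_seq_id_buf.assign(batch.n_tokens, 1);
        seq_id_buf.assign(batch.n_tokens, &seq_id_0);
        batch.n_seq_id = n_seq_id_buf.data();
        batch.seq_id   = seq_id_buf.data();
    }
}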

@ngxson
Collaborator Author

ngxson commented Oct 5, 2024

all_pos_0, all_pos_1, all_seq_id are not truly redundant when the alternative requires allocating an array of positions, which requires a lot more code than simply calling llama_batch_get_one.

Let me clarify a bit more: what I meant was that in all examples, we always set:

  • all_pos_0 = n_past
  • all_pos_1 = 1
  • all_seq_id = 0

So I assume that in 99% of cases, if the user wants to work with a single sequence (the most basic usage), then all_pos_1 and all_seq_id are redundant. all_pos_0, as said earlier, can be tracked internally, so it is somewhat redundant too.

Then you could remove all the code in llama.cpp that has to deal with all_pos_0, all_pos_1, all_seq_id, single sequence users could continue using llama_batch_get_one without using the common library (which we should never encourage), and it would avoid a breaking API change.

The problem with such a change is that even without touching llama_batch_get_one, just removing all_pos_0, all_pos_1, all_seq_id is already a breaking change, because the layout of struct llama_batch will change.

It seems OK to me to keep llama_batch_get_one in the core library though. One idea is that it produces a batch with just n_tokens and token set, while pos is tracked internally (see the sketch below). This covers 99% of the basic usage (single sequence); for the 1% of cases where the user still wants a single sequence but with custom token positions, they now need to construct the batch themselves.
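
Roughly, such a simplified llama_batch_get_one would amount to the following (a sketch based on this discussion and the doc comments further down; only n_tokens and token are set, and the NULL fields are filled in by llama_decode):

struct llama_batch llama_batch_get_one(llama_token * tokens, int32_t n_tokens) {
    return {
        /*n_tokens =*/ n_tokens,
        /*token    =*/ tokens,
        /*embd     =*/ nullptr, // no embeddings
        /*pos      =*/ nullptr, // positions tracked internally by llama_decode
        /*n_seq_id =*/ nullptr,
        /*seq_id   =*/ nullptr, // sequence ID assumed to be 0
        /*logits   =*/ nullptr, // logits returned only for the last token
    };
}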

In any case, I still strongly prefer to remove all_pos_0, all_pos_1, all_seq_id altogether, because they are no longer the recommended usage.

@slaren
Collaborator

slaren commented Oct 5, 2024

Sounds good to me. Other than causing an ABI break, removing all_pos_0, all_pos_1, all_seq_id will probably not break much code. The important API to keep is llama_batch_get_one.

@ngxson ngxson force-pushed the xsn/llama_batch_remove_compat branch from 697a3f9 to 1c48616 on October 11, 2024 10:11
Comment on lines 234 to 240
// - pos : the positions of the respective token in the sequence
// (if set to NULL, the token position will be tracked automatically by llama_decode)
// - seq_id : the sequence to which the respective token belongs
// (if set to NULL, the sequence ID will be assumed to be 0)
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
// (if set to NULL, only the logits for last token will be returned)
//
Collaborator Author

@slaren @ggerganov I updated the behavior of llama_batch to adapt to the removal of all_pos_0, all_pos_1, all_seq_id, please let me know what you think about this implementation. Thank you!

@@ -221,7 +221,7 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str.c_str());
result2 += next_token_str;

-if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) {
+if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1))) {
Owner

This will generate a batch for seq_id == 0 and it needs to be seq_id == 1

make -j && ./llama-save-load-state -m ${some_model}

Collaborator Author

@ngxson ngxson Oct 12, 2024

Thanks for spotting that! Fixed in 6395174

@ngxson ngxson changed the title llama : move llama_batch_get_one from core library to common llama : remove all_pos_0, all_pos_1, all_seq_id from llama_batch Oct 11, 2024
@@ -412,13 +412,22 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);

llama_batch batch = llama_batch_init(batch_size, 0, 1);
Owner

Move the llama_batch outside the loop and reuse it. Maybe utilize the common_batch_ API to make it a little less cumbersome.

Collaborator Author

Fixed in 734f9e2 and 4be7ecf

src/llama.cpp Outdated
batch.n_seq_id = n_seq_id.data();
}
if (!batch.seq_id) {
seq_id.resize(batch.n_tokens);
Owner

Make this also NULL terminated for consistency (see llama_batch_init):

Suggested change
seq_id.resize(batch.n_tokens);
seq_id.resize(batch.n_tokens + 1);

Collaborator Author

Fixed in 7264596

@@ -376,7 +376,7 @@ int main(int argc, char ** argv) {
n_past, n_left, n_ctx, params.n_keep, n_discard);

llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
-llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past + 1, -n_discard);
Collaborator Author

@ngxson ngxson Oct 12, 2024

Small explanation of what's happening: we are supposed to shift all tokens starting from n_keep + n_discard + 1, so the end of the range must be n_past + 1 (or we can simply set it to -1, which means [p0, inf)).

Owner

@ggerganov ggerganov Oct 14, 2024

Hm, I don't think n_past + 1 is needed here. There shouldn't be a token with pos == n_past in the KV cache.

But yes, using either n_past or -1 would achieve the same thing. Think using n_past is more illustrative.

Collaborator Author

OK, thanks. I figured out that I was counting tokens from 1, not from 0. I fixed that in 5d99ae4
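
For reference, the context-shift pattern being discussed looks roughly like this with zero-based positions (a sketch; the last cached token sits at pos n_past - 1, so either n_past or -1 works as the end of the shifted range):

llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1,             params.n_keep + 1 + n_discard); // drop n_discard tokens after the kept prefix
llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);            // shift the remaining tokens back by n_discard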

@ngxson ngxson mentioned this pull request Oct 15, 2024
@ngxson ngxson requested a review from ggerganov October 18, 2024 13:57
@ngxson ngxson merged commit cda0e4b into ggerganov:master Oct 18, 2024
53 checks passed
dsx1986 pushed a commit to dsx1986/llama.cpp that referenced this pull request Oct 29, 2024
…rganov#9745)

* refactor llama_batch_get_one

* adapt all examples

* fix simple.cpp

* fix llama_bench

* fix

* fix context shifting

* free batch before return

* use common_batch_add, reuse llama_batch in loop

* null terminated seq_id list

* fix save-load-state example

* fix perplexity

* correct token pos in llama_batch_allocr
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 15, 2024
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Nov 18, 2024