
Add beam search #631

Open · wants to merge 2 commits into base: main
Conversation

mattpulver commented Aug 23, 2023

Invoke by adding "beam_width": 2 (for example) to the /v1/completions POST body.

This PR will be moved out of Draft mode after ggerganov/llama.cpp#2267 is merged. Closes #145 Closes #340 Closes #185

In the meantime, there is a question:

How does one specify --logits_all False when invoked from the command line?

python3 -m llama_cpp.server --logits_all False

results in settings.logits_all=True on startup.


koskoakos commented Sep 4, 2023

> results in settings.logits_all=True on startup.

You can pass False with an empty string (--logits_all ''), but a better way would be to add either action='store_true' or action=argparse.BooleanOptionalAction when adding the argument.
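
For illustration, a minimal standalone sketch of that suggestion (not the actual llama_cpp.server argument-parsing code; argparse.BooleanOptionalAction needs Python 3.9+):

    import argparse

    parser = argparse.ArgumentParser()
    # BooleanOptionalAction generates paired --logits_all / --no-logits_all flags,
    # so an explicit False no longer needs the empty-string workaround.
    parser.add_argument(
        "--logits_all",
        action=argparse.BooleanOptionalAction,
        default=True,
    )

    print(parser.parse_args([]))                   # Namespace(logits_all=True)
    print(parser.parse_args(["--no-logits_all"]))  # Namespace(logits_all=False)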

@Avinash-Raj

@mattpulver Guess now you can un-draft this PR.

@mattpulver mattpulver marked this pull request as ready for review September 19, 2023 14:09

mattpulver (Author)

> You can pass False with an empty string (--logits_all ''), but a better way would be to add either action='store_true' or action=argparse.BooleanOptionalAction when adding the argument.

When I add --logits_all '' to python -m llama_cpp.server it errors with:

__main__.py: error: argument --logits_all: invalid parse_bool_arg value: ''

In the meantime, I changed the default setting for logits_all from True to False (3fce944), but I welcome better suggestions. (I'm not sure where exactly the action='store_true' change should go.)

mattpulver (Author)

Testing

  • Start the web server.
  • Go to http://localhost:8000/docs (adjust port as needed)
  • Open/click the POST /v1/completions panel.
  • Press the Try it out button.
  • Edit the Example Value JSON by adding "beam_width": 2, (a Python equivalent of the full request is sketched after this list)
  • Press Execute
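
For reference, a rough Python equivalent of the same request (the prompt, port, and max_tokens here are placeholders, not part of this PR):

    import requests

    # Assumes llama_cpp.server is running locally on port 8000.
    resp = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "prompt": "def fibonacci(n):",
            "max_tokens": 64,
            "beam_width": 2,  # enables beam search with 2 beams
        },
        timeout=600,
    )
    print(resp.json()["choices"][0]["text"])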

If you would like to see how the beams evolve and their probabilities, uncomment:

    #print(f"\n\nCurrent beams (last_call={beams_state.last_call}):\n")
    #for i in range(beams_state.n_beams):
    #    print(f"beams[{i}]", beam_view_to_string(callback_data.ctx,beams_state.beam_views[i]))

in llama_cpp/llama.py.

mattpulver (Author)

Resolves #184

@mattpulver mattpulver changed the title Add beam search. Invoke by adding "beam_search": 2 (for example) to /v1/completions POST. Add beam search. Invoke by adding "beam_width": 2 (for example) to /v1/completions POST. Sep 19, 2023

abetlen (Owner) commented Sep 20, 2023

@mattpulver great work here, I'll review this and should have it merged this week.

Cheers

@mattpulver
Copy link
Author

@abetlen Thanks. Perhaps the most intrusive integration change is changing the default value of the logits_all command line parameter from True to False: 3fce944

This actually matches the default value in llama.cpp, so in a broader sense it makes sense IMO, but it may break some existing functionality.

@abetlen abetlen changed the title Add beam search. Invoke by adding "beam_width": 2 (for example) to /v1/completions POST. Add beam search. Closes #145 #340 #185 Sep 30, 2023
@abetlen abetlen changed the title Add beam search. Closes #145 #340 #185 Add beam search Sep 30, 2023

abetlen (Owner) commented Oct 5, 2023

@mattpulver just a quick update: I'm going to hold off on merging this until after #771, because that's going to have a big impact on how we use the llama.cpp API internally in the Llama class. Once that's in, I'll take a look at dealing with the merge conflicts here. One thing to note: I won't change the default behaviour of logits_all, as this would constitute a breaking change for a number of users, so we need to find a better solution (automatically reload the model with logits_all=False) or inform the user that it has to be set to false for beam search.

beam_search_dictionary = {}

# beam_search_callback() must flag beams when they reach end-of-sentence.
# TODO: Use stop_sequences.

cebtenzzre (Contributor) commented Oct 25, 2023

Is this TODO the reason I'm seeing this in the debug output? Note the EOS and BOS. The prompt is not code-related, FWIW.

beams[0] p(0.493342787027359): <0x0A></s><s><0x0A>#include▁<iostream><0x0A>#include▁<cmath.h><0x0A>#include▁<c
beams[1] p(0.5066572427749634): <0x0A></s><s><0x0A>#include▁<iostream><0x0A>#include▁<cmath.h><0x0A>#include▁<vector

mattpulver (Author)

I'm not sure what you mean exactly by "Note the EOS and BOS."

The TODO note relates to the is_at_eob() function above. Currently, EOB (end-of-beam) is determined by the token llama_cpp.llama_token_eos(ctx). If EOB is to be generalized to user-defined EOB sequences, then this would be the function to add the logic to.

cebtenzzre (Contributor)

What I mean is that </s> (EOS) is generated by the model, but the beam search keeps going (onto BOS, and then it starts making up something unrelated). I think this shouldn't happen.

mattpulver (Author) commented Oct 25, 2023

Thanks. To answer your original question, yes, that is exactly what the TODO is talking about.

A good follow-up item would be to add stop_sequences to the class beam_search_callback_data and set them to custom stop sequences (e.g. </s>) when the class is instantiated below. Then pass it to is_at_eob() when called from beam_search_callback().

It may require a bit more logic to accommodate the possibility of stop sequences being split across separate tokens.
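
A rough sketch of that idea (the names beam_search_callback_data and is_at_eob come from this PR, but the fields and signatures below are assumptions, and the suffix check is deliberately naive):

    import llama_cpp

    class beam_search_callback_data:
        def __init__(self, ctx, stop_sequences=None):
            self.ctx = ctx
            self.stop_sequences = stop_sequences or []  # e.g. ["</s>"]

    def is_at_eob(callback_data, tokens, n_tokens) -> bool:
        # End-of-beam on the EOS token, as the PR already does.
        if n_tokens > 0 and tokens[n_tokens - 1] == llama_cpp.llama_token_eos(callback_data.ctx):
            return True
        # Also treat user-supplied stop sequences as end-of-beam. A real
        # implementation must handle stop sequences split across tokens;
        # this endswith() check is only a starting point.
        text = b"".join(
            llama_cpp.llama_token_get_text(callback_data.ctx, tokens[i])
            for i in range(n_tokens)
        ).decode("utf-8", errors="ignore")
        return any(text.endswith(stop) for stop in callback_data.stop_sequences)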

cebtenzzre (Contributor)

To be clear, this is not a custom stop sequence, this is just the regular EOS token (AFAIK), which is rendered this way in the output. You say EOB is determined by llama_token_eos, but that doesn't seem to work for me.

mattpulver (Author)

One way to debug this is to modify the line above

 string += llama_cpp.llama_token_get_text(ctx, beam_view.tokens[i]).decode("utf-8")

to something that appends the numeric token id beam_view.tokens[i] along with the decoded substring. If you're really encountering the llama_token_eos() token, last I checked, it should have token id 2.
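
For example, a hedged sketch of that debug change (llama_beam_view exposes tokens and n_tokens; the helper name here is made up):

    def beam_tokens_to_string(ctx, beam_view) -> str:
        # Render each token as "[id]text" so a real EOS token (id 2 in the
        # LLaMA vocabulary) is unambiguous even when its text renders oddly.
        parts = []
        for i in range(beam_view.n_tokens):
            token_id = beam_view.tokens[i]
            text = llama_cpp.llama_token_get_text(ctx, token_id).decode("utf-8", errors="ignore")
            parts.append(f"[{token_id}]{text}")
        return "".join(parts)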

cebtenzzre (Contributor) commented Oct 26, 2023

Oops! Sorry, this one is my fault. llama_token_eos takes model now, not ctx, and I missed that when I merged this PR into my local branch. Unfortunately, ctypes gives no indication of pointer type mismatches. Normally I would use mypy, but it seems that llama-cpp-python is not tested against it: there are many type errors and other complaints.
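
For reference, with the newer API the EOS lookup from a context would look roughly like this (a sketch, assuming the low-level bindings expose llama_get_model):

    model = llama_cpp.llama_get_model(ctx)
    eos_id = llama_cpp.llama_token_eos(model)  # instead of llama_token_eos(ctx)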

cebtenzzre (Contributor) commented Oct 25, 2023

If I wait a little longer, I consistently hit GGML_ASSERT: llama-cpp-python/vendor/llama.cpp/llama.cpp:5967: n_tokens <= n_batch with this applied to the latest llama-cpp-python.

beam.tokens.size() ends up at 513 here, which is one more than the batch size: https://github.com/ggerganov/llama.cpp/blob/master/llama.cpp#L7849-L7852

// beam is not at end-of-sentence, so branch with next top_k tokens.
if (!beam.tokens.empty()) {
    llama_decode(ctx, llama_batch_get_one(beam.tokens.data(), beam.tokens.size(), n_past, 0));
}

backtrace (line numbers may not be accurate):

#4  0x00007f59ef28c43b in llama_decode_internal (lctx=..., batch=...) at /home/cebtenzzre/src/forks/llama-cpp-python/vendor/llama.cpp/llama.cpp:5970
#5  0x00007f59ef28c906 in llama_decode (ctx=<optimized out>, batch=<error reading variable: Cannot access memory at address 0x8>)
    at /home/cebtenzzre/src/forks/llama-cpp-python/vendor/llama.cpp/llama.cpp:9781
#6  0x00007f59ef2a45df in llama_beam_search_data::fill_next_beams_by_top_probabilities (this=this@entry=0x7f5a0e1faca0, beam=...)
    at /home/cebtenzzre/src/forks/llama-cpp-python/vendor/llama.cpp/llama.cpp:7951
#7  0x00007f59ef2a530f in llama_beam_search_data::loop (this=this@entry=0x7f5a0e1faca0, callback=callback@entry=0x7f5b2dfa3010, 
    callback_data=callback_data@entry=0x7f5a15c768d0) at /home/cebtenzzre/src/forks/llama-cpp-python/vendor/llama.cpp/llama.cpp:8032
#8  0x00007f59ef28e45c in llama_beam_search (ctx=0x7f59e9e34e30, callback=0x7f5b2dfa3010, callback_data=0x7f5a15c768d0, n_beams=<optimized out>, 
    n_past=<optimized out>, n_predict=<optimized out>) at /home/cebtenzzre/src/forks/llama-cpp-python/vendor/llama.cpp/llama.cpp:8072

I can reproduce it within 30 seconds or so with a 7B model on CUDA on commit 1a1c3dc. On earlier commits, it seems to just hang.

Fixing the EOS issue on my end does not resolve this.

@abetlen abetlen force-pushed the main branch 2 times, most recently from 8c93cf8 to cc0fe43 Compare November 14, 2023 20:24

rishsriv commented Jan 9, 2024

Thanks @mattpulver for the PR (both here and in the llama.cpp repo).

Curious if abetlen, cebtenzzre, or someone else knows if/when this will get merged? I've found beam search extremely helpful for code generation, and would love to know if it'll be supported with the main llama-cpp-python library in the near future, or if I should create a binary from Matt's PR instead.

abetlen (Owner) commented Jan 11, 2024

Hey @rishsriv, I'm still planning to merge this; however, I'm currently grinding through batch processing support first, as it requires a bunch of internal refactoring. After that I plan to come back and merge this. Can't give an ETA, though; likely in the next few weeks if I had to guess.

@rishsriv

Got it, thank you for the response! If there are things you think external contributors will be able to fix, please do open an issue and I'll be happy to help.

cebtenzzre (Contributor)

> If I wait a little longer, I consistently hit GGML_ASSERT: llama-cpp-python/vendor/llama.cpp/llama.cpp:5967: n_tokens <= n_batch with this applied to the latest llama-cpp-python.

Possibly related: ggerganov/llama.cpp#6664

@ExtReMLapin (Contributor)

Any update on this?

@ExtReMLapin (Contributor)

Alright, so I copied the changes into a local clone and it seems to be working, or at least running.

First it spends a very long time in llama_cpp.llama_beam_search, and then the token output rate is quite low, which makes the sampling time insanely high.

beam_width=1

HEAD of master:

llama_print_timings:        load time =     154.96 ms
llama_print_timings:      sample time =    1534.47 ms /   548 runs   (    2.80 ms per token,   357.13 tokens per second)
llama_print_timings: prompt eval time =     371.38 ms /  1661 tokens (    0.22 ms per token,  4472.54 tokens per second)
llama_print_timings:        eval time =    4816.73 ms /   547 runs   (    8.81 ms per token,   113.56 tokens per second)

This fork:

llama_print_timings:        load time =     127.43 ms
llama_print_timings:      sample time =   59429.62 ms /     1 runs   (59429.62 ms per token,     0.02 tokens per second)
llama_print_timings: prompt eval time =     338.67 ms /  1661 tokens (    0.20 ms per token,  4904.55 tokens per second)
llama_print_timings:        eval time =   58430.05 ms /  6530 runs   (    8.95 ms per token,   111.76 tokens per second)

59429.62 ms per token
