Skip to content

Commit

Permalink
whisper : improve handling of prompts
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Mar 21, 2024
1 parent 48a1452 commit 5c2c07d
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
2 changes: 1 addition & 1 deletion examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str());
fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false");
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str());
fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str());
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
Expand Down
9 changes: 7 additions & 2 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3721,7 +3721,7 @@ int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_to

if (n_max_tokens < (int) res.size()) {
WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens);
return -1;
return -(int) res.size();
}

for (int i = 0; i < (int) res.size(); i++) {
Expand Down Expand Up @@ -5313,7 +5313,12 @@ int whisper_full_with_state(
// initial prompt
if (!params.prompt_tokens && params.initial_prompt) {
prompt_tokens.resize(1024);
prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
if (n_needed < 0) {
prompt_tokens.resize(-n_needed);
n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size());
}
prompt_tokens.resize(n_needed);
params.prompt_tokens = prompt_tokens.data();
params.prompt_n_tokens = prompt_tokens.size();
}
Expand Down
4 changes: 3 additions & 1 deletion whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ extern "C" {
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
// Returns -1 on failure
// Returns a negative number on failure - the negated number of tokens that would have been returned (i.e. the required buffer size)
// TODO: not sure if correct
WHISPER_API int whisper_tokenize(
struct whisper_context * ctx,
Expand Down Expand Up @@ -503,6 +503,8 @@ extern "C" {

// tokens to provide to the whisper decoder as initial prompt
// these are prepended to any existing text context from a previous call
// use whisper_tokenize() to convert text to tokens
// a maximum of whisper_n_text_ctx()/2 tokens is used
const char * initial_prompt;
const whisper_token * prompt_tokens;
int prompt_n_tokens;
Expand Down

0 comments on commit 5c2c07d

Please sign in to comment.