From 2e3b4f62378082680f47b58080f771439a31075f Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Sun, 3 Dec 2023 04:23:14 -0600 Subject: [PATCH 1/7] Check the full vocab for grammar only if necessary --- common/sampling.cpp | 36 +++++++++++++++++++++++++++++++----- common/sampling.h | 9 +++++---- examples/infill/infill.cpp | 2 +- examples/main/main.cpp | 2 +- 4 files changed, 38 insertions(+), 11 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 1317024c2c11c..78092611b1995 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -100,10 +100,11 @@ std::string llama_sampling_print(const llama_sampling_params & params) { } llama_token llama_sampling_sample( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - struct llama_context * ctx_cfg, - const int idx) { + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + const int idx, + bool is_resampling) { // Add a parameter to indicate if we are resampling const llama_sampling_params & params = ctx_sampling->params; const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); @@ -128,7 +129,11 @@ llama_token llama_sampling_sample( llama_token id = 0; + // Get a pointer to the logits float * logits = llama_get_logits_ith(ctx_main, idx); + + // Make a copy of the original logits before any modifications + std::vector original_logits(logits, logits + llama_n_vocab(llama_get_model(ctx_main))); // apply params.logit_bias map for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { @@ -165,7 +170,8 @@ llama_token llama_sampling_sample( } } - if (ctx_sampling->grammar != NULL) { + // If we are in the resampling phase, apply grammar checks before sampling logic + if (is_resampling && ctx_sampling->grammar != NULL) { llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar); } @@ -212,6 +218,26 @@ llama_token llama_sampling_sample( } } + if (ctx_sampling->grammar != NULL && !is_resampling) { + // Create an array with a single token data element for the sampled id + llama_token_data single_token_data = {id, logits[id], 0.0f}; + llama_token_data_array single_token_data_array = { &single_token_data, 1, false }; + + // Apply grammar constraints to the single token + llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar); + + // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY + bool is_valid = single_token_data_array.data[0].logit != -INFINITY; + + // If the token is not valid according to the grammar, perform resampling + if (!is_valid) { + LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str()); + + // Recursively call llama_sampling_sample to resample with the grammar checks applied first + return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling + } + } + return id; } diff --git a/common/sampling.h b/common/sampling.h index 7c9b8dcf23bcb..5c387fb6f1af9 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -98,10 +98,11 @@ std::string llama_sampling_print(const llama_sampling_params & params); // - candidates: vector of candidate tokens // llama_token llama_sampling_sample( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - struct llama_context * ctx_cfg, - int idx = 0); + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + const int idx, + bool is_resampling = false); // Add the new parameter with default value void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 4a7827876e215..c4a38e5e2fbbb 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -527,7 +527,7 @@ int main(int argc, char ** argv) { if ((int) embd_inp.size() <= n_consumed && !is_interacting) { - const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); + const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false); llama_sampling_accept(ctx_sampling, ctx, id, true); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index c5cdfbf21b954..c67493dc6759d 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -630,7 +630,7 @@ int main(int argc, char ** argv) { LOG("saved session to %s\n", path_session.c_str()); } - const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); + const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false); llama_sampling_accept(ctx_sampling, ctx, id, true); From 281e2bad8c0ab087dbf1f6f307f60e9173a371d7 Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Sun, 3 Dec 2023 05:25:44 -0600 Subject: [PATCH 2/7] Fix missing logit restoration step (?) Does this matter, actually? --- common/sampling.cpp | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 78092611b1995..d87340d2b803b 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -131,9 +131,14 @@ llama_token llama_sampling_sample( // Get a pointer to the logits float * logits = llama_get_logits_ith(ctx_main, idx); + + // Declare original_logits at the beginning of the function scope + std::vector original_logits; - // Make a copy of the original logits before any modifications - std::vector original_logits(logits, logits + llama_n_vocab(llama_get_model(ctx_main))); + if (!is_resampling) { + // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this. + original_logits = std::vector(logits, logits + llama_n_vocab(llama_get_model(ctx_main))); + } // apply params.logit_bias map for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { @@ -233,6 +238,9 @@ llama_token llama_sampling_sample( if (!is_valid) { LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str()); + // Restore logits from the copy + std::copy(original_logits.begin(), original_logits.end(), logits); + // Recursively call llama_sampling_sample to resample with the grammar checks applied first return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling } From de454b9ef52000dce89eb7cd405e1757fbf8d9bc Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Sun, 3 Dec 2023 05:43:25 -0600 Subject: [PATCH 3/7] Fix whitespace / formatting --- common/sampling.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index d87340d2b803b..f5ac665124198 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -100,11 +100,11 @@ std::string llama_sampling_print(const llama_sampling_params & params) { } llama_token llama_sampling_sample( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - struct llama_context * ctx_cfg, - const int idx, - bool is_resampling) { // Add a parameter to indicate if we are resampling + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + const int idx, + bool is_resampling) { // Add a parameter to indicate if we are resampling const llama_sampling_params & params = ctx_sampling->params; const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); @@ -134,7 +134,7 @@ llama_token llama_sampling_sample( // Declare original_logits at the beginning of the function scope std::vector original_logits; - + if (!is_resampling) { // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this. original_logits = std::vector(logits, logits + llama_n_vocab(llama_get_model(ctx_main))); From 245de1fc67905a9fceb6d7f050d1d066cf1d077e Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Sun, 3 Dec 2023 10:41:56 -0600 Subject: [PATCH 4/7] Adjust comment --- common/sampling.h | 3 +- grammars/extreme.gbnf | 117 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 1 deletion(-) create mode 100644 grammars/extreme.gbnf diff --git a/common/sampling.h b/common/sampling.h index 5c387fb6f1af9..1c1297f64f881 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -92,6 +92,7 @@ std::string llama_sampling_print(const llama_sampling_params & params); // optional: // - ctx_cfg: context to use for classifier-free guidance // - idx: sample from llama_get_logits_ith(ctx, idx) +// - is_resampling: determines whether or not this is a repeated sampling operation due to the ID not matching the grammar // // returns: // - token: sampled token @@ -102,7 +103,7 @@ llama_token llama_sampling_sample( struct llama_context * ctx_main, struct llama_context * ctx_cfg, const int idx, - bool is_resampling = false); // Add the new parameter with default value + bool is_resampling = false); void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, diff --git a/grammars/extreme.gbnf b/grammars/extreme.gbnf new file mode 100644 index 0000000000000..c267f68b953df --- /dev/null +++ b/grammars/extreme.gbnf @@ -0,0 +1,117 @@ +root ::= [ \t\n]* exp + +ws ::= [ \t\n]+ +w ::= [ \t]* + +comment ::= "#" [^#]* "#" [ \t]+ [\n]? [ \t]* + +### Expressions + +exp ::= comment* sequence-exp + +sequence-exp ::= tuple-exp (w ";" ws tuple-exp)* + +tuple-exp ::= cons-exp (w "," ws cons-exp)* + +cons-exp ::= binary-exp (w "::" w binary-exp)* + +binary-exp ::= unary-exp (ws binary-op ws unary-exp)* + +unary-exp ::= unary-op* function-app-exp + +function-app-exp ::= primary-exp (w "(" w exp w ")" w)* + +primary-exp ::= bool | + integer | + float | + string | + variable | + "()" | + "[]" | + constructor | + constructor-app | + parenthesized-exp | + list-exp | + let-exp | + if-exp | + case-exp | + test-exp | + type-alias | + fun + +constructor-app ::= constructor "(" w exp w ")" +parenthesized-exp ::= "(" w exp w ")" +list-exp ::= "[" exp ("," ws exp)* "]" +let-exp ::= "let" ws pat ws "=" ws exp ws "in" ws exp +if-exp ::= "if" ws exp ws "then" ws exp ws "else" ws exp +case-exp ::= "case" ws exp (ws "|" ws pat ws "=>" ws exp)+ ws "end" +test-exp ::= "test" ws exp ws "end" +type-alias ::= "type" ws constructor ws "=" ws typ ws "in" ws exp +fun ::= "fun" ws pat ws "->" ws exp + +type-variable ::= [a-z][A-Za-z0-9_]* +constructor ::= [A-Z][A-Za-z0-9_]* +variable ::= ([_a-bdg-hj-kn-qu-z][A-Za-z0-9_.]*)|(("s" ([.0-9A-Z_a-su-z][A-Za-z0-9_.]*)?)|("st" ([.0-9A-Z_a-qs-z][A-Za-z0-9_.]*)?)|("str" ([.0-9A-Z_a-tv-z][A-Za-z0-9_.]*)?)|("stru" ([.0-9A-Z_a-bd-z][A-Za-z0-9_.]*)?)|("struc" ([.0-9A-Z_a-su-z][A-Za-z0-9_.]*)?)|("struct" [A-Za-z0-9_.]+)|("c" ([.0-9A-Z_b-z][A-Za-z0-9_.]*)?)|("ca" ([.0-9A-Z_a-rt-z][A-Za-z0-9_.]*)?)|("cas" ([.0-9A-Z_a-df-z][A-Za-z0-9_.]*)?)|("case" [A-Za-z0-9_.]+)|("i" ([.0-9A-Z_a-mo-z][A-Za-z0-9_.]*)?)|("in" [A-Za-z0-9_.]+)|("r" ([.0-9A-Z_a-df-z][A-Za-z0-9_.]*)?)|("re" ([.0-9A-Z_a-bd-z][A-Za-z0-9_.]*)?)|("rec" [A-Za-z0-9_.]+)|("t" ([.0-9A-Z_a-df-z][A-Za-z0-9_.]*)?)|("te" ([.0-9A-Z_a-rt-z][A-Za-z0-9_.]*)?)|("tes" ([.0-9A-Z_a-su-z][A-Za-z0-9_.]*)?)|("test" [A-Za-z0-9_.]+)|("l" ([.0-9A-Z_a-df-z][A-Za-z0-9_.]*)?)|("le" ([.0-9A-Z_a-su-z][A-Za-z0-9_.]*)?)|("let" [A-Za-z0-9_.]+)|("m" ([.0-9A-Z_b-z][A-Za-z0-9_.]*)?)|("ma" ([.0-9A-Z_a-su-z][A-Za-z0-9_.]*)?)|("mat" ([.0-9A-Z_a-bd-z][A-Za-z0-9_.]*)?)|("matc" ([.0-9A-Z_a-gi-z][A-Za-z0-9_.]*)?)|("match" [A-Za-z0-9_.]+)|("f" ([.0-9A-Z_a-tv-z][A-Za-z0-9_.]*)?)|("fu" ([.0-9A-Z_a-mo-z][A-Za-z0-9_.]*)?)|("fun" [A-Za-z0-9_.]+)|("e" ([.0-9A-Z_a-mo-z][A-Za-z0-9_.]*)?)|("en" ([.0-9A-Z_a-ce-z][A-Za-z0-9_.]*)?)|("end" [A-Za-z0-9_.]+)) +bool ::= "true" | "false" +integer ::= [0-9]+ +float ::= [0-9]* "." [0-9]+ +string ::= "\"" [^"]* "\"" + +unary-op ::= "-" | "!" +binary-op-int ::= "+" | "-" | "*" | "/" | "<" | ">" | "<=" | ">=" | "==" | "!=" +binary-op-float ::= "+." | "-." | "*." | "/." | "<." | ">." | "<=." | ">=." | "==." | "!=." +binary-op-string ::= "$==" | "@" +binary-op-logic ::= "&&" +binary-op ::= binary-op-int | binary-op-float | binary-op-string | binary-op-logic + +### Patterns + +pat ::= type-ascription-pat + +type-ascription-pat ::= tuple-pat (w ":" ws typ)* + +tuple-pat ::= cons-pat (w "," ws cons-pat)* + +cons-pat ::= primary-pat (w "::" w primary-pat)* + +primary-pat ::= + bool | + integer | + float | + string | + variable | + "()" | + "[]" | + "_" | + constructor | + constructor-app-pat | + parenthesized-pat | + list-pat + +constructor-app-pat ::= constructor "(" w pat w ")" +parenthesized-pat ::= "(" w pat w ")" +list-pat ::= "[" pat (w "," ws pat)* "]" + +### Types + +typ ::= arrow-typ + +arrow-typ ::= tuple-typ (ws "->" ws tuple-typ)* + +tuple-typ ::= primary-typ (w "," ws primary-typ)* + +primary-typ ::= + "Unit" | + "Int" | + "Float" | + "Bool" | + "String" | + type-variable | + constructor | + constructor-def (ws "+" ws constructor-def)+ | + parenthesized-typ | + list-typ + +parenthesized-typ ::= "(" w typ w ")" +list-typ ::= "[" w typ w "]" +constructor-def ::= constructor | constructor "(" w typ w ")" \ No newline at end of file From f5f9d9620ba41edade36f06f62235286af40a35b Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Sun, 3 Dec 2023 10:55:37 -0600 Subject: [PATCH 5/7] Didn't mean to push test gbnf --- grammars/extreme.gbnf | 117 ------------------------------------------ 1 file changed, 117 deletions(-) delete mode 100644 grammars/extreme.gbnf diff --git a/grammars/extreme.gbnf b/grammars/extreme.gbnf deleted file mode 100644 index c267f68b953df..0000000000000 --- a/grammars/extreme.gbnf +++ /dev/null @@ -1,117 +0,0 @@ -root ::= [ \t\n]* exp - -ws ::= [ \t\n]+ -w ::= [ \t]* - -comment ::= "#" [^#]* "#" [ \t]+ [\n]? [ \t]* - -### Expressions - -exp ::= comment* sequence-exp - -sequence-exp ::= tuple-exp (w ";" ws tuple-exp)* - -tuple-exp ::= cons-exp (w "," ws cons-exp)* - -cons-exp ::= binary-exp (w "::" w binary-exp)* - -binary-exp ::= unary-exp (ws binary-op ws unary-exp)* - -unary-exp ::= unary-op* function-app-exp - -function-app-exp ::= primary-exp (w "(" w exp w ")" w)* - -primary-exp ::= bool | - integer | - float | - string | - variable | - "()" | - "[]" | - constructor | - constructor-app | - parenthesized-exp | - list-exp | - let-exp | - if-exp | - case-exp | - test-exp | - type-alias | - fun - -constructor-app ::= constructor "(" w exp w ")" -parenthesized-exp ::= "(" w exp w ")" -list-exp ::= "[" exp ("," ws exp)* "]" -let-exp ::= "let" ws pat ws "=" ws exp ws "in" ws exp -if-exp ::= "if" ws exp ws "then" ws exp ws "else" ws exp -case-exp ::= "case" ws exp (ws "|" ws pat ws "=>" ws exp)+ ws "end" -test-exp ::= "test" ws exp ws "end" -type-alias ::= "type" ws constructor ws "=" ws typ ws "in" ws exp -fun ::= "fun" ws pat ws "->" ws exp - -type-variable ::= [a-z][A-Za-z0-9_]* -constructor ::= [A-Z][A-Za-z0-9_]* -variable ::= ([_a-bdg-hj-kn-qu-z][A-Za-z0-9_.]*)|(("s" ([.0-9A-Z_a-su-z][A-Za-z0-9_.]*)?)|("st" ([.0-9A-Z_a-qs-z][A-Za-z0-9_.]*)?)|("str" ([.0-9A-Z_a-tv-z][A-Za-z0-9_.]*)?)|("stru" ([.0-9A-Z_a-bd-z][A-Za-z0-9_.]*)?)|("struc" ([.0-9A-Z_a-su-z][A-Za-z0-9_.]*)?)|("struct" [A-Za-z0-9_.]+)|("c" ([.0-9A-Z_b-z][A-Za-z0-9_.]*)?)|("ca" ([.0-9A-Z_a-rt-z][A-Za-z0-9_.]*)?)|("cas" ([.0-9A-Z_a-df-z][A-Za-z0-9_.]*)?)|("case" [A-Za-z0-9_.]+)|("i" ([.0-9A-Z_a-mo-z][A-Za-z0-9_.]*)?)|("in" [A-Za-z0-9_.]+)|("r" ([.0-9A-Z_a-df-z][A-Za-z0-9_.]*)?)|("re" ([.0-9A-Z_a-bd-z][A-Za-z0-9_.]*)?)|("rec" [A-Za-z0-9_.]+)|("t" ([.0-9A-Z_a-df-z][A-Za-z0-9_.]*)?)|("te" ([.0-9A-Z_a-rt-z][A-Za-z0-9_.]*)?)|("tes" ([.0-9A-Z_a-su-z][A-Za-z0-9_.]*)?)|("test" [A-Za-z0-9_.]+)|("l" ([.0-9A-Z_a-df-z][A-Za-z0-9_.]*)?)|("le" ([.0-9A-Z_a-su-z][A-Za-z0-9_.]*)?)|("let" [A-Za-z0-9_.]+)|("m" ([.0-9A-Z_b-z][A-Za-z0-9_.]*)?)|("ma" ([.0-9A-Z_a-su-z][A-Za-z0-9_.]*)?)|("mat" ([.0-9A-Z_a-bd-z][A-Za-z0-9_.]*)?)|("matc" ([.0-9A-Z_a-gi-z][A-Za-z0-9_.]*)?)|("match" [A-Za-z0-9_.]+)|("f" ([.0-9A-Z_a-tv-z][A-Za-z0-9_.]*)?)|("fu" ([.0-9A-Z_a-mo-z][A-Za-z0-9_.]*)?)|("fun" [A-Za-z0-9_.]+)|("e" ([.0-9A-Z_a-mo-z][A-Za-z0-9_.]*)?)|("en" ([.0-9A-Z_a-ce-z][A-Za-z0-9_.]*)?)|("end" [A-Za-z0-9_.]+)) -bool ::= "true" | "false" -integer ::= [0-9]+ -float ::= [0-9]* "." [0-9]+ -string ::= "\"" [^"]* "\"" - -unary-op ::= "-" | "!" -binary-op-int ::= "+" | "-" | "*" | "/" | "<" | ">" | "<=" | ">=" | "==" | "!=" -binary-op-float ::= "+." | "-." | "*." | "/." | "<." | ">." | "<=." | ">=." | "==." | "!=." -binary-op-string ::= "$==" | "@" -binary-op-logic ::= "&&" -binary-op ::= binary-op-int | binary-op-float | binary-op-string | binary-op-logic - -### Patterns - -pat ::= type-ascription-pat - -type-ascription-pat ::= tuple-pat (w ":" ws typ)* - -tuple-pat ::= cons-pat (w "," ws cons-pat)* - -cons-pat ::= primary-pat (w "::" w primary-pat)* - -primary-pat ::= - bool | - integer | - float | - string | - variable | - "()" | - "[]" | - "_" | - constructor | - constructor-app-pat | - parenthesized-pat | - list-pat - -constructor-app-pat ::= constructor "(" w pat w ")" -parenthesized-pat ::= "(" w pat w ")" -list-pat ::= "[" pat (w "," ws pat)* "]" - -### Types - -typ ::= arrow-typ - -arrow-typ ::= tuple-typ (ws "->" ws tuple-typ)* - -tuple-typ ::= primary-typ (w "," ws primary-typ)* - -primary-typ ::= - "Unit" | - "Int" | - "Float" | - "Bool" | - "String" | - type-variable | - constructor | - constructor-def (ws "+" ws constructor-def)+ | - parenthesized-typ | - list-typ - -parenthesized-typ ::= "(" w typ w ")" -list-typ ::= "[" w typ w "]" -constructor-def ::= constructor | constructor "(" w typ w ")" \ No newline at end of file From 115a9218ebc47fc3bc9542a63cee6e10c2e8e1ae Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Thu, 7 Dec 2023 14:58:48 -0600 Subject: [PATCH 6/7] Split sampling into the helper function (?) And also revert the changes made to the header --- common/sampling.cpp | 14 +++++++++++--- common/sampling.h | 12 +++++------- examples/infill/infill.cpp | 2 +- examples/main/main.cpp | 2 +- 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index f5ac665124198..6344c29dae0c4 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -99,7 +99,7 @@ std::string llama_sampling_print(const llama_sampling_params & params) { return std::string(result); } -llama_token llama_sampling_sample( +static llama_token llama_sampling_sample_impl( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, struct llama_context * ctx_cfg, @@ -241,14 +241,22 @@ llama_token llama_sampling_sample( // Restore logits from the copy std::copy(original_logits.begin(), original_logits.end(), logits); - // Recursively call llama_sampling_sample to resample with the grammar checks applied first - return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling + return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling } } return id; } +llama_token llama_sampling_sample( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + const int idx) { + // Call the implementation function with is_resampling set to false by default + return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false); +} + void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, diff --git a/common/sampling.h b/common/sampling.h index 1c1297f64f881..4a8c522b67b89 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -92,21 +92,19 @@ std::string llama_sampling_print(const llama_sampling_params & params); // optional: // - ctx_cfg: context to use for classifier-free guidance // - idx: sample from llama_get_logits_ith(ctx, idx) -// - is_resampling: determines whether or not this is a repeated sampling operation due to the ID not matching the grammar // // returns: // - token: sampled token // - candidates: vector of candidate tokens // llama_token llama_sampling_sample( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - struct llama_context * ctx_cfg, - const int idx, - bool is_resampling = false); + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + int idx = 0); void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, llama_token id, - bool apply_grammar); + bool apply_grammar); \ No newline at end of file diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index c4a38e5e2fbbb..4a7827876e215 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -527,7 +527,7 @@ int main(int argc, char ** argv) { if ((int) embd_inp.size() <= n_consumed && !is_interacting) { - const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false); + const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); llama_sampling_accept(ctx_sampling, ctx, id, true); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index c67493dc6759d..c5cdfbf21b954 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -630,7 +630,7 @@ int main(int argc, char ** argv) { LOG("saved session to %s\n", path_session.c_str()); } - const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false); + const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); llama_sampling_accept(ctx_sampling, ctx, id, true); From 88fd22c3fc0469c8b1157df10d6e9cbd5c5cc601 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 23 Dec 2023 11:25:34 +0200 Subject: [PATCH 7/7] common : fix final newline --- common/sampling.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/sampling.h b/common/sampling.h index 7dbb865aca0d4..fdfa9eed1467b 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -111,4 +111,4 @@ void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, llama_token id, - bool apply_grammar); \ No newline at end of file + bool apply_grammar);