From bf14ca7dc523fa9ab1a77cdee1cb36ff412a136a Mon Sep 17 00:00:00 2001 From: Christian Fillion Date: Fri, 7 Feb 2025 04:33:27 -0500 Subject: [PATCH] llama : add llama_sampler_init for safe usage of llama_sampler_free (#11727) The C API in llama.h claims users can implement `llama_sampler_i` to create custom `llama_sampler`. The sampler chain takes ownership and calls `llama_sampler_free` on them. However, `llama_sampler_free` is hard-coded to use `delete`. This is undefined behavior if the object wasn't also allocated via `new` from libllama's C++ runtime. Callers in C and C-compatible languages do not use C++'s `new` operator. C++ callers may not be sharing the same heap as libllama. --- common/llguidance.cpp | 6 +- include/llama.h | 5 +- src/llama-sampling.cpp | 121 ++++++++++++++++++++++------------------- 3 files changed, 70 insertions(+), 62 deletions(-) diff --git a/common/llguidance.cpp b/common/llguidance.cpp index 7aa8ddd80297b..2feeb93c87e30 100644 --- a/common/llguidance.cpp +++ b/common/llguidance.cpp @@ -254,10 +254,10 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g }; } - return new llama_sampler{ + return llama_sampler_init( /* .iface = */ &llama_sampler_llg_i, - /* .ctx = */ ctx, - }; + /* .ctx = */ ctx + ); } #else diff --git a/include/llama.h b/include/llama.h index 61907ed404dbf..3784f7d3950e5 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1114,11 +1114,12 @@ extern "C" { }; struct llama_sampler { - struct llama_sampler_i * iface; - llama_sampler_context_t ctx; + const struct llama_sampler_i * iface; + llama_sampler_context_t ctx; }; // mirror of llama_sampler_i: + LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx); LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 26974f5396565..990b6129746de 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -316,6 +316,13 @@ static uint32_t get_rng_seed(uint32_t seed) { // llama_sampler API +struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) { + return new llama_sampler { + /* .iface = */ iface, + /* .ctx = */ ctx, + }; +} + const char * llama_sampler_name(const struct llama_sampler * smpl) { if (!smpl->iface) { return "(null)"; @@ -347,10 +354,10 @@ struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) { } if (smpl->ctx == nullptr) { - return new llama_sampler { + return llama_sampler_init( /* .iface = */ smpl->iface, - /* .ctx = */ nullptr, - }; + /* .ctx = */ nullptr + ); } GGML_ABORT("the sampler does not support cloning"); @@ -472,15 +479,15 @@ static struct llama_sampler_i llama_sampler_chain_i = { }; struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { - return new llama_sampler { + return llama_sampler_init( /* .iface = */ &llama_sampler_chain_i, /* .ctx = */ new llama_sampler_chain { /* .params = */ params, /* .samplers = */ {}, /* .t_sample_us = */ 0, /* .n_sample = */ 0, - }, - }; + } + ); } void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { @@ -546,10 +553,10 @@ static struct llama_sampler_i llama_sampler_greedy_i = { }; struct llama_sampler * llama_sampler_init_greedy() { - return new llama_sampler { + return llama_sampler_init( /* .iface = */ &llama_sampler_greedy_i, - /* .ctx = */ nullptr, - }; + /* .ctx = */ nullptr + ); } // dist @@ -608,14 +615,14 @@ static struct llama_sampler_i llama_sampler_dist_i = { struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { auto seed_cur = get_rng_seed(seed); - return new llama_sampler { + return llama_sampler_init( /* .iface = */ &llama_sampler_dist_i, /* .ctx = */ new llama_sampler_dist { /* .seed = */ seed, /* .seed_cur = */ seed_cur, /* .rng = */ std::mt19937(seed_cur), - }, - }; + } + ); } // softmax @@ -638,10 +645,10 @@ static struct llama_sampler_i llama_sampler_softmax_i = { }; struct llama_sampler * llama_sampler_init_softmax() { - return new llama_sampler { + return llama_sampler_init( /* .iface = */ &llama_sampler_softmax_i, - /* .ctx = */ nullptr, - }; + /* .ctx = */ nullptr + ); } // top-k @@ -678,12 +685,12 @@ static struct llama_sampler_i llama_sampler_top_k_i = { }; struct llama_sampler * llama_sampler_init_top_k(int32_t k) { - return new llama_sampler { + return llama_sampler_init( /* .iface = */ &llama_sampler_top_k_i, /* .ctx = */ new llama_sampler_top_k { /* .k = */ k, - }, - }; + } + ); } // top-p @@ -744,13 +751,13 @@ static struct llama_sampler_i llama_sampler_top_p_i = { }; struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { - return new llama_sampler { + return llama_sampler_init( /* .iface = */ &llama_sampler_top_p_i, /* .ctx = */ new llama_sampler_top_p { /* .p = */ p, /* .min_keep = */ min_keep, - }, - }; + } + ); } // min-p @@ -840,13 +847,13 @@ static struct llama_sampler_i llama_sampler_min_p_i = { }; struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { - return new llama_sampler { + return llama_sampler_init( /* .iface = */ &llama_sampler_min_p_i, /* .ctx = */ new llama_sampler_min_p { /* .p = */ p, /* .min_keep = */ min_keep, - }, - }; + } + ); } // typical @@ -939,13 +946,13 @@ static struct llama_sampler_i llama_sampler_typical_i = { }; struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { - return new llama_sampler { + return llama_sampler_init( /* .iface = */ &llama_sampler_typical_i, /* .ctx = */ new llama_sampler_typical { /* .p = */ p, /* .min_keep = */ min_keep, - }, - }; + } + ); } // temp @@ -983,12 +990,12 @@ static struct llama_sampler_i llama_sampler_temp_i = { }; struct llama_sampler * llama_sampler_init_temp(float temp) { - return new llama_sampler { + return llama_sampler_init( /* .iface = */ &llama_sampler_temp_i, /* .ctx = */ new llama_sampler_temp { /*.temp = */ temp, - }, - }; + } + ); } // temp-ext @@ -1093,14 +1100,14 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = { }; struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { - return new llama_sampler { + return llama_sampler_init( /* .iface = */ &llama_sampler_temp_ext_i, /* .ctx = */ new llama_sampler_temp_ext { /* .temp = */ temp, /* .delta = */ delta, /* .exponent = */ exponent, - }, - }; + } + ); } // xtc @@ -1185,7 +1192,7 @@ static struct llama_sampler_i llama_sampler_xtc_i = { struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { auto seed_cur = get_rng_seed(seed); - return new llama_sampler { + return llama_sampler_init( /* .iface = */ &llama_sampler_xtc_i, /* .ctx = */ new llama_sampler_xtc { /* .probability = */ p, @@ -1194,8 +1201,8 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, /* .seed = */ seed, /* .seed_cur = */ seed_cur, /* .rng = */ std::mt19937(seed_cur), - }, - }; + } + ); } // mirostat @@ -1292,7 +1299,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = { struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { auto seed_cur = get_rng_seed(seed); - return new llama_sampler { + return llama_sampler_init( /* .iface = */ &llama_sampler_mirostat_i, /* .ctx = */ new llama_sampler_mirostat { /* .n_vocab = */ n_vocab, @@ -1303,8 +1310,8 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see /* .m = */ m, /* .mu = */ 2.0f*tau, /* .rng = */ std::mt19937(seed_cur), - }, - }; + } + ); } // mirostat v2 @@ -1391,7 +1398,7 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = { struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { auto seed_cur = get_rng_seed(seed); - return new llama_sampler { + return llama_sampler_init( /* .iface = */ &llama_sampler_mirostat_v2_i, /* .ctx = */ new llama_sampler_mirostat_v2 { /* .seed = */ seed, @@ -1400,8 +1407,8 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, /* .eta = */ eta, /* .mu = */ 2.0f*tau, /* .rng = */ std::mt19937(seed_cur), - }, - }; + } + ); } // grammar @@ -1528,10 +1535,10 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( }; } - return new llama_sampler { + return llama_sampler_init( /* .iface = */ &llama_sampler_grammar_i, - /* .ctx = */ ctx, - }; + /* .ctx = */ ctx + ); } struct llama_sampler * llama_sampler_init_grammar( @@ -1678,7 +1685,7 @@ struct llama_sampler * llama_sampler_init_penalties( float penalty_present) { penalty_last_n = std::max(penalty_last_n, 0); - return new llama_sampler { + return llama_sampler_init( /* .iface = */ &llama_sampler_penalties_i, /* .ctx = */ new llama_sampler_penalties { /* .penalty_last_n = */ penalty_last_n, @@ -1687,8 +1694,8 @@ struct llama_sampler * llama_sampler_init_penalties( /* .penalty_present = */ penalty_present, /* .prev = */ ring_buffer(penalty_last_n), /* .token_count = */ {}, - }, - }; + } + ); } // DRY @@ -2041,7 +2048,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, } } - return new llama_sampler { + return llama_sampler_init( /* .iface = */ &llama_sampler_dry_i, /* .ctx = */ new llama_sampler_dry { /* .total_context_size = */ context_size, @@ -2053,8 +2060,8 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, /* .dry_repeat_count = */ dry_enabled ? std::vector(effective_dry_penalty_last_n, 0) : std::vector{}, /* .dry_max_token_repeat = */ {}, /* .last_tokens = */ dry_enabled ? ring_buffer(effective_dry_penalty_last_n) : ring_buffer(0), - }, - }; + } + ); } // wrapper for test-sampling.cpp @@ -2155,14 +2162,14 @@ struct llama_sampler * llama_sampler_init_logit_bias( int32_t n_vocab, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { - return new llama_sampler { + return llama_sampler_init( /* .iface = */ &llama_sampler_logit_bias_i, /* .ctx = */ new llama_sampler_logit_bias { /* .n_vocab = */ n_vocab, /* .logit_bias = */ std::vector(logit_bias, logit_bias + n_logit_bias), /* .to_search = */ {}, - }, - }; + } + ); } // infill @@ -2377,14 +2384,14 @@ static struct llama_sampler_i llama_sampler_infill_i = { }; struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) { - return new llama_sampler { + return llama_sampler_init( /* .iface = */ &llama_sampler_infill_i, /* .ctx = */ new llama_sampler_infill { /* .vocab = */ vocab, /* .buf0 = */ std::vector(512), /* .buf1 = */ std::vector(512), - }, - }; + } + ); } // utils