Skip to content

Commit

Permalink
llama : add llama_sampler_init for safe usage of llama_sampler_free (g…
Browse files Browse the repository at this point in the history
…gml-org#11727)

The C API in llama.h claims users can implement `llama_sampler_i` to
create custom `llama_sampler`. The sampler chain takes ownership and
calls `llama_sampler_free` on them. However, `llama_sampler_free` is
hard-coded to use `delete`. This is undefined behavior if the object
wasn't also allocated via `new` from libllama's C++ runtime. Callers
in C and C-compatible languages do not use C++'s `new` operator. C++
callers may not be sharing the same heap as libllama.
  • Loading branch information
cfillion authored and NeoZhangJianyu committed Feb 25, 2025
1 parent cf7f1c1 commit bf14ca7
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 62 deletions.
6 changes: 3 additions & 3 deletions common/llguidance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,10 @@ llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * g
};
}

return new llama_sampler{
return llama_sampler_init(
/* .iface = */ &llama_sampler_llg_i,
/* .ctx = */ ctx,
};
/* .ctx = */ ctx
);
}

#else
Expand Down
5 changes: 3 additions & 2 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -1114,11 +1114,12 @@ extern "C" {
};

struct llama_sampler {
struct llama_sampler_i * iface;
llama_sampler_context_t ctx;
const struct llama_sampler_i * iface;
llama_sampler_context_t ctx;
};

// mirror of llama_sampler_i:
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
Expand Down
121 changes: 64 additions & 57 deletions src/llama-sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,13 @@ static uint32_t get_rng_seed(uint32_t seed) {

// llama_sampler API

struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) {
return new llama_sampler {
/* .iface = */ iface,
/* .ctx = */ ctx,
};
}

const char * llama_sampler_name(const struct llama_sampler * smpl) {
if (!smpl->iface) {
return "(null)";
Expand Down Expand Up @@ -347,10 +354,10 @@ struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
}

if (smpl->ctx == nullptr) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ smpl->iface,
/* .ctx = */ nullptr,
};
/* .ctx = */ nullptr
);
}

GGML_ABORT("the sampler does not support cloning");
Expand Down Expand Up @@ -472,15 +479,15 @@ static struct llama_sampler_i llama_sampler_chain_i = {
};

struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_chain_i,
/* .ctx = */ new llama_sampler_chain {
/* .params = */ params,
/* .samplers = */ {},
/* .t_sample_us = */ 0,
/* .n_sample = */ 0,
},
};
}
);
}

void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
Expand Down Expand Up @@ -546,10 +553,10 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
};

struct llama_sampler * llama_sampler_init_greedy() {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_greedy_i,
/* .ctx = */ nullptr,
};
/* .ctx = */ nullptr
);
}

// dist
Expand Down Expand Up @@ -608,14 +615,14 @@ static struct llama_sampler_i llama_sampler_dist_i = {

struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
auto seed_cur = get_rng_seed(seed);
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_dist_i,
/* .ctx = */ new llama_sampler_dist {
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
},
};
}
);
}

// softmax
Expand All @@ -638,10 +645,10 @@ static struct llama_sampler_i llama_sampler_softmax_i = {
};

struct llama_sampler * llama_sampler_init_softmax() {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_softmax_i,
/* .ctx = */ nullptr,
};
/* .ctx = */ nullptr
);
}

// top-k
Expand Down Expand Up @@ -678,12 +685,12 @@ static struct llama_sampler_i llama_sampler_top_k_i = {
};

struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_top_k_i,
/* .ctx = */ new llama_sampler_top_k {
/* .k = */ k,
},
};
}
);
}

// top-p
Expand Down Expand Up @@ -744,13 +751,13 @@ static struct llama_sampler_i llama_sampler_top_p_i = {
};

struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_top_p_i,
/* .ctx = */ new llama_sampler_top_p {
/* .p = */ p,
/* .min_keep = */ min_keep,
},
};
}
);
}

// min-p
Expand Down Expand Up @@ -840,13 +847,13 @@ static struct llama_sampler_i llama_sampler_min_p_i = {
};

struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_min_p_i,
/* .ctx = */ new llama_sampler_min_p {
/* .p = */ p,
/* .min_keep = */ min_keep,
},
};
}
);
}

// typical
Expand Down Expand Up @@ -939,13 +946,13 @@ static struct llama_sampler_i llama_sampler_typical_i = {
};

struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_typical_i,
/* .ctx = */ new llama_sampler_typical {
/* .p = */ p,
/* .min_keep = */ min_keep,
},
};
}
);
}

// temp
Expand Down Expand Up @@ -983,12 +990,12 @@ static struct llama_sampler_i llama_sampler_temp_i = {
};

struct llama_sampler * llama_sampler_init_temp(float temp) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_temp_i,
/* .ctx = */ new llama_sampler_temp {
/*.temp = */ temp,
},
};
}
);
}

// temp-ext
Expand Down Expand Up @@ -1093,14 +1100,14 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
};

struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_temp_ext_i,
/* .ctx = */ new llama_sampler_temp_ext {
/* .temp = */ temp,
/* .delta = */ delta,
/* .exponent = */ exponent,
},
};
}
);
}

// xtc
Expand Down Expand Up @@ -1185,7 +1192,7 @@ static struct llama_sampler_i llama_sampler_xtc_i = {

struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
auto seed_cur = get_rng_seed(seed);
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_xtc_i,
/* .ctx = */ new llama_sampler_xtc {
/* .probability = */ p,
Expand All @@ -1194,8 +1201,8 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
},
};
}
);
}

// mirostat
Expand Down Expand Up @@ -1292,7 +1299,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {

struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
auto seed_cur = get_rng_seed(seed);
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_mirostat_i,
/* .ctx = */ new llama_sampler_mirostat {
/* .n_vocab = */ n_vocab,
Expand All @@ -1303,8 +1310,8 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
/* .m = */ m,
/* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed_cur),
},
};
}
);
}

// mirostat v2
Expand Down Expand Up @@ -1391,7 +1398,7 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {

struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
auto seed_cur = get_rng_seed(seed);
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_mirostat_v2_i,
/* .ctx = */ new llama_sampler_mirostat_v2 {
/* .seed = */ seed,
Expand All @@ -1400,8 +1407,8 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau,
/* .eta = */ eta,
/* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed_cur),
},
};
}
);
}

// grammar
Expand Down Expand Up @@ -1528,10 +1535,10 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
};
}

return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_grammar_i,
/* .ctx = */ ctx,
};
/* .ctx = */ ctx
);
}

struct llama_sampler * llama_sampler_init_grammar(
Expand Down Expand Up @@ -1678,7 +1685,7 @@ struct llama_sampler * llama_sampler_init_penalties(
float penalty_present) {
penalty_last_n = std::max(penalty_last_n, 0);

return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_penalties_i,
/* .ctx = */ new llama_sampler_penalties {
/* .penalty_last_n = */ penalty_last_n,
Expand All @@ -1687,8 +1694,8 @@ struct llama_sampler * llama_sampler_init_penalties(
/* .penalty_present = */ penalty_present,
/* .prev = */ ring_buffer<llama_token>(penalty_last_n),
/* .token_count = */ {},
},
};
}
);
}

// DRY
Expand Down Expand Up @@ -2041,7 +2048,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
}
}

return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_dry_i,
/* .ctx = */ new llama_sampler_dry {
/* .total_context_size = */ context_size,
Expand All @@ -2053,8 +2060,8 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
/* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
/* .dry_max_token_repeat = */ {},
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
},
};
}
);
}

// wrapper for test-sampling.cpp
Expand Down Expand Up @@ -2155,14 +2162,14 @@ struct llama_sampler * llama_sampler_init_logit_bias(
int32_t n_vocab,
int32_t n_logit_bias,
const llama_logit_bias * logit_bias) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_logit_bias_i,
/* .ctx = */ new llama_sampler_logit_bias {
/* .n_vocab = */ n_vocab,
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
/* .to_search = */ {},
},
};
}
);
}

// infill
Expand Down Expand Up @@ -2377,14 +2384,14 @@ static struct llama_sampler_i llama_sampler_infill_i = {
};

struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
return new llama_sampler {
return llama_sampler_init(
/* .iface = */ &llama_sampler_infill_i,
/* .ctx = */ new llama_sampler_infill {
/* .vocab = */ vocab,
/* .buf0 = */ std::vector<char>(512),
/* .buf1 = */ std::vector<char>(512),
},
};
}
);
}

// utils
Expand Down

0 comments on commit bf14ca7

Please sign in to comment.