Skip to content

Commit

Permalink
llama : improve infill sampler
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov committed Oct 12, 2024
1 parent 8343eeb commit c7d8904
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 23 deletions.
7 changes: 4 additions & 3 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -1152,9 +1152,9 @@ extern "C" {
const llama_logit_bias * logit_bias);

// this sampler is meant to be used for fill-in-the-middle infilling
// it's supposed to be used after top_k sampling and will leave a single candidate token
// it's supposed to be used after top_k sampling
//
// 1. if there is a high-prob token (>= 0.9f) -> pick it
// 1. if there is a high-prob token (>= 0.9f) -> skip step 2
// 2. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
// 3. combine probs of tokens that have the same prefix
//
Expand All @@ -1170,7 +1170,8 @@ extern "C" {
// "hel": 0.8
// "dummy": 0.1
//
// 4. pick the token with the highest probability
// 4. discard non-EOG tokens with low prob (< 0.2)
// 5. if no tokens are left -> pick EOT
//
LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);

Expand Down
44 changes: 24 additions & 20 deletions src/llama-sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1663,7 +1663,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_

#if defined(GGML_DEBUG_SAMPLER_INFILL)
for (size_t i = 0; i < cur_p->size; ++i) {
LLAMA_LOG_DEBUG("infill: cur_p[%zu] = { id: %d, p: %f, logit: %f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
LLAMA_LOG_DEBUG("infill: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
}
#endif

Expand All @@ -1673,14 +1673,16 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_

for (size_t i = 0; i < cur_p->size; ++i) {
p_max = fmaxf(p_max, cur_p->data[i].p);

if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
p_eog_sum += cur_p->data[i].p;
} else {
p_txt_sum += cur_p->data[i].p;
}
}

const float rat = p_txt_sum / p_eog_sum;
const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum;

LLAMA_LOG_DEBUG("infill: p_max = %.2f, p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", p_max, p_txt_sum, p_eog_sum, rat, cur_p->size);

if (p_max < 0.90f && p_eog_sum*cur_p->size > p_txt_sum) {
Expand Down Expand Up @@ -1712,48 +1714,50 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
}

if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
if (cur_p->data[i].p > cur_p->data[j].p) {
if (cur_p->data[i].p > cur_p->data[j].p) {
cur_p->data[i].p += cur_p->data[j].p;
cur_p->data[j].logit = -INFINITY;
cur_p->data[j].p = 0.0f;
} else {
cur_p->data[j].p += cur_p->data[i].p;
cur_p->data[i].logit = -INFINITY;
cur_p->data[i].p = 0.0f;
}
}
}
}

// mask non-EOG tokens with prob < 0.2
for (size_t i = 0; i < cur_p->size; ++i) {
const auto size_org = cur_p->size;

cur_p->size = 0;

float p_sum = 0.0f;

for (size_t i = 0; i < size_org; ++i) {
// discard non-EOG tokens with prob < 0.2
if (cur_p->data[i].p < 0.2 && !llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
cur_p->data[i].logit = -INFINITY;
continue;
}
}

// determine the token with max logit
float l_max = -INFINITY;
int i_max = -1;
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].logit > l_max) {
l_max = cur_p->data[i].logit;
i_max = i;
}
// keep this token
p_sum += cur_p->data[i].p;

cur_p->data[cur_p->size++] = cur_p->data[i];
}

// if all probs are -INFINITY -> reduce cur_p to single EOG token
if (i_max == -1) {
if (cur_p->size == 0) {
cur_p->size = 1;
cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
cur_p->data[0].logit = 1.0f;

return;
}

// pick the best token
cur_p->size = 1;
cur_p->data[0] = cur_p->data[i_max];

// normalize probs
for (size_t i = 0; i < cur_p->size; ++i) {
cur_p->data[i].p /= p_sum;

LLAMA_LOG_DEBUG("after : cur_p[%zu] = { id: %d, p: %f, logit: %f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
}
}
Expand Down

0 comments on commit c7d8904

Please sign in to comment.