diff --git a/include/llama.h b/include/llama.h index efbc6314c6988..6081297c7a2fd 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1152,9 +1152,9 @@ extern "C" { const llama_logit_bias * logit_bias); // this sampler is meant to be used for fill-in-the-middle infilling - // it's supposed to be used after top_k sampling and will leave a single candidate token + // it's supposed to be used after top_k sampling // - // 1. if there is a high-prob token (>= 0.9f) -> pick it + // 1. if there is a high-prob token (>= 0.9f) -> skip step 2 // 2. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG // 3. combine probs of tokens that have the same prefix // @@ -1170,7 +1170,8 @@ extern "C" { // "hel": 0.8 // "dummy": 0.1 // - // 4. pick the token with the highest probability + // 4. discard non-EOG tokens with low prob (< 0.2) + // 5. if no tokens are left -> pick EOT // LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index d0c351fd72066..60d8a9e709f93 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1663,7 +1663,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_ #if defined(GGML_DEBUG_SAMPLER_INFILL) for (size_t i = 0; i < cur_p->size; ++i) { - LLAMA_LOG_DEBUG("infill: cur_p[%zu] = { id: %d, p: %f, logit: %f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit); + LLAMA_LOG_DEBUG("infill: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit); } #endif @@ -1673,6 +1673,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_ for (size_t i = 0; i < cur_p->size; ++i) { p_max = fmaxf(p_max, cur_p->data[i].p); + if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) { p_eog_sum += cur_p->data[i].p; } else { @@ -1680,7 +1681,8 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_ } } - const float rat = p_txt_sum / p_eog_sum; + const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; + LLAMA_LOG_DEBUG("infill: p_max = %.2f, p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", p_max, p_txt_sum, p_eog_sum, rat, cur_p->size); if (p_max < 0.90f && p_eog_sum*cur_p->size > p_txt_sum) { @@ -1712,36 +1714,39 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_ } if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) { - if (cur_p->data[i].p > cur_p->data[j].p) { + if (cur_p->data[i].p > cur_p->data[j].p) { cur_p->data[i].p += cur_p->data[j].p; cur_p->data[j].logit = -INFINITY; + cur_p->data[j].p = 0.0f; } else { cur_p->data[j].p += cur_p->data[i].p; cur_p->data[i].logit = -INFINITY; + cur_p->data[i].p = 0.0f; } } } } - // mask non-EOG tokens with prob < 0.2 - for (size_t i = 0; i < cur_p->size; ++i) { + const auto size_org = cur_p->size; + + cur_p->size = 0; + + float p_sum = 0.0f; + + for (size_t i = 0; i < size_org; ++i) { + // discard non-EOG tokens with prob < 0.2 if (cur_p->data[i].p < 0.2 && !llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) { - cur_p->data[i].logit = -INFINITY; + continue; } - } - // determine the token with max logit - float l_max = -INFINITY; - int i_max = -1; - for (size_t i = 0; i < cur_p->size; ++i) { - if (cur_p->data[i].logit > l_max) { - l_max = cur_p->data[i].logit; - i_max = i; - } + // keep this token + p_sum += cur_p->data[i].p; + + cur_p->data[cur_p->size++] = cur_p->data[i]; } // if all probs are -INFINITY -> reduce cur_p to single EOG token - if (i_max == -1) { + if (cur_p->size == 0) { cur_p->size = 1; cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab); cur_p->data[0].logit = 1.0f; @@ -1749,11 +1754,10 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_ return; } - // pick the best token - cur_p->size = 1; - cur_p->data[0] = cur_p->data[i_max]; - + // normalize probs for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= p_sum; + LLAMA_LOG_DEBUG("after : cur_p[%zu] = { id: %d, p: %f, logit: %f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit); } }