diff --git a/common/sampling.cpp b/common/sampling.cpp index c626ca03c11e1..bb7de3769d145 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -433,10 +433,10 @@ static llama_token_data_array llama_sampling_prepare_impl( { const int penalty_tokens_used_size = std::min(penalty_tokens.size(), (size_t)dry_penalty_last_n); if (penalty_tokens_used_size) { - llama_sample_dry(&cur_p, + llama_sample_dry(ctx_main, &cur_p, penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, penalty_tokens_used_size, dry_base, dry_multiplier, dry_allowed_length, - params.dry_seq_breakers.data(), params.dry_seq_breakers.size()); + params.dry_seq_breakers); } } diff --git a/common/sampling.h b/common/sampling.h index 80c2568cf2b41..1f864c4764764 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -46,6 +46,8 @@ typedef struct llama_sampling_params { uint32_t dry_allowed_length = 2; int32_t dry_penalty_last_n = -1; // DRY last n tokens to penalize (0 = disable penalty, -1 = context size) + std::vector dry_seq_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY + std::vector samplers_sequence = { llama_sampler_type::TOP_K, llama_sampler_type::TFS_Z, @@ -63,9 +65,8 @@ typedef struct llama_sampling_params { float cfg_scale = 1.f; // how strong is guidance std::unordered_map logit_bias; // logit bias for specific tokens - std::vector penalty_prompt_tokens; - std::vector dry_seq_breakers; // sequence breakers for the DRY sampler + bool use_penalty_prompt_tokens = false; } llama_sampling_params; diff --git a/include/llama.h b/include/llama.h index 51ed8d9ee2402..81805e5c2ae7e 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1085,16 +1085,17 @@ extern "C" { float p, size_t min_keep); - /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677 - LLAMA_API void llama_sample_dry( - llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t last_tokens_size, - float dry_base, - float dry_multiplier, - int dry_allowed_length, - const llama_token * dry_seq_breakers, - size_t dry_seq_breakers_size); + // /// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677 + // LLAMA_API void llama_sample_dry( + // struct llama_context * ctx, + // llama_token_data_array * candidates, + // const llama_token * last_tokens, + // size_t last_tokens_size, + // float dry_base, + // float dry_multiplier, + // int dry_allowed_length, + // const std::vector + // & dry_seq_breakers); /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. LLAMA_API void llama_sample_tail_free( @@ -1246,6 +1247,18 @@ std::pair, llama_partial_utf8> decode_utf8( // This is a temporary workaround in order to fix race conditions when sampling with multiple sequences. llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng); +/// @details DRY sampler as described in: https://github.com/oobabooga/text-generation-webui/pull/5677 +LLAMA_API void llama_sample_dry( + struct llama_context * ctx, + llama_token_data_array * candidates, + const llama_token * last_tokens, + size_t last_tokens_size, + float dry_base, + float dry_multiplier, + int dry_allowed_length, + const std::vector + & dry_seq_breakers); + #endif // LLAMA_API_INTERNAL #endif // LLAMA_H diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 375717accbd5d..22e623fac7656 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -232,94 +232,230 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_arra } } -void llama_sample_dry_impl(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) { - // skip dry sampler if we don't have a previous token - if (last_tokens_size < 1) return; +std::vector llama_tokenize( + const struct llama_context * ctx, + const std::string & text, + bool add_special, + bool parse_special) { + return llama_tokenize(llama_get_model(ctx), text, add_special, parse_special); +} + +std::vector llama_tokenize( + const struct llama_model * model, + const std::string & text, + bool add_special, + bool parse_special) { + // upper limit for the number of tokens + int n_tokens = text.length() + 2 * add_special; + std::vector result(n_tokens); + n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); + if (n_tokens < 0) { + result.resize(-n_tokens); + int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); + GGML_ASSERT(check == -n_tokens); + } else { + result.resize(n_tokens); + } + return result; +} + +std::string llama_detokenize(llama_context * ctx, const std::vector & tokens, bool special) { + std::string text; + text.resize(std::max(text.capacity(), tokens.size())); + int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + if (n_chars < 0) { + text.resize(-n_chars); + n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization + } + + text.resize(n_chars); + + // NOTE: the original tokenizer decodes bytes after collecting the pieces. + return text; +} + +std::string llama_detokenize_single(llama_context * ctx, llama_token token, bool special) { + std::vector tokens = {token}; + return llama_detokenize(ctx, tokens, special); +} - // get the last token - auto last_token = last_tokens[last_tokens_size - 1]; +// Constants for preventing overflow +const float FLOAT_MAX_LOG = 88.7228391f; +const int MAX_CHAR_LEN = 40; +const int MAX_SEQ_LEN = 20; - // if last token is part of the sequence breakers, skip whole sampler - if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, last_token) != dry_seq_breakers + dry_seq_breakers_size) { + +void llama_sample_dry_impl(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector & dry_seq_breakers) { + if (last_tokens_size < 1) { return; } - // create an unordered map of "next tokens" <-> max match length + // Cache for token-to-string conversions + std::unordered_map token_to_string_cache; + // Store sequence breakers for more efficient lookup + std::unordered_multimap> restart_sequences; + + auto detokenize_with_cache = [&](llama_token token) -> std::string { + auto it = token_to_string_cache.find(token); + if (it != token_to_string_cache.end()) { + return it->second; + } + std::string token_str = llama_detokenize_single(ctx, token, false); + token_to_string_cache[token] = token_str; + return token_str; + }; + + // Pre-process dry_seq_breakers + for (const auto& breaker : dry_seq_breakers) { + std::string breaker_trimmed = breaker.substr(0, MAX_CHAR_LEN); + std::vector tokens = llama_tokenize(ctx, breaker_trimmed, false, false); + + if (!tokens.empty()) { + std::string head = detokenize_with_cache(tokens[0]); + std::vector tail; + + for (size_t i = 1; i < tokens.size() && i <= MAX_SEQ_LEN; ++i) { + tail.push_back(detokenize_with_cache(tokens[i])); + } + restart_sequences.emplace(head, tail); + } + } + + // Find max repetition length considering restart sequences + int rep_limit = last_tokens_size; + + for (size_t i = 0; i < last_tokens_size; ++i) { + size_t ix = last_tokens_size - 1 - i; + std::string token_str = detokenize_with_cache(last_tokens[ix]); + + // Check if the token is a potential sequence breaker + auto its = restart_sequences.equal_range(token_str); + if (its.first == restart_sequences.end()) continue; + + int longest_match = -1; + // Check all potential sequence breakers starting with this token + for (auto it = its.first; it != its.second; ++it) { + int seq_len = (int)it->second.size(); + if (seq_len > longest_match && seq_len <= i) { + bool match = true; + // Check if the following tokens match the sequence breaker + for (size_t offset = 0; offset < seq_len; ++offset) { + if (it->second[offset] != detokenize_with_cache(last_tokens[ix + 1 + offset])) { + match = false; + break; + } + } + if (match) { + longest_match = seq_len; + } + } + } + + if (longest_match >= 0) { + rep_limit = static_cast(i) - longest_match; + break; + } + } + + if (rep_limit <= dry_allowed_length) { + return; + } + + // Store max match length for each token std::unordered_map match_lengths; - // loop through each previous token (exclude the last token) + // Find repeated sequences for (size_t i = 0; i < last_tokens_size - 1; ++i) { - // skip if the compare token is not the same as the last token - if (last_tokens[i] != last_token) { + if (last_tokens[i] != last_tokens[last_tokens_size - 1]) { continue; } - // get the next token (i + 1 is always less than last_tokens_size) auto next_token = last_tokens[i + 1]; + std::string next_token_str = detokenize_with_cache(next_token); - // if next token is part of the sequence breakers, skip - if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, next_token) != dry_seq_breakers + dry_seq_breakers_size) { + // Skip if next token is a sequence breaker + auto its = restart_sequences.equal_range(next_token_str); + if (its.first != restart_sequences.end()) { continue; } - // try to extend the match backwards (match length starts at 1 because last token is already matched) size_t match_length = 1; - // loop through the previous tokens + // Extend match as far as possible for (;; match_length++) { - // if we have reached the start of our last tokens, break - if (i < match_length) break; + if (i < match_length || match_length > rep_limit) { + break; + } - // compare token starts at our prev index, going backwards by match length auto compare_token = last_tokens[i - match_length]; + std::string compare_token_str = detokenize_with_cache(compare_token); - // head token starts at the end of last tokens, going backwards by match length, minus 1 because we start at the last token itself auto head_token = last_tokens[last_tokens_size - 1 - match_length]; + std::string head_token_str = detokenize_with_cache(head_token); - // break out of the match if any tokens don't match - if (compare_token != head_token) { + if (compare_token_str != head_token_str) { break; } - // if compare token is part of the sequence breakers, break out of the match - if (std::find(dry_seq_breakers, dry_seq_breakers + dry_seq_breakers_size, compare_token) != dry_seq_breakers + dry_seq_breakers_size) { + // Check if we've hit a sequence breaker + its = restart_sequences.equal_range(compare_token_str); + if (its.first != restart_sequences.end()) { break; } } - // Check if the next token exists in the map + // Update max match length for this token auto it = match_lengths.find(next_token); - if (it == match_lengths.end()) { - // Key does not exist, insert the new value match_lengths[next_token] = match_length; } else { - // Key exists, update it with the max of the new value or the existing value it->second = std::max(it->second, match_length); } } - // apply penalties + // Calculate max safe exponent + int max_exponent = 0; + if (dry_base > 1.000001f) { + max_exponent = static_cast(FLOAT_MAX_LOG / log(dry_base)); + } + +#ifdef DEBUG + LLAMA_LOG_INFO("DRY Sampling parameters:\n"); + LLAMA_LOG_INFO(" dry_base: %f\n", dry_base); + LLAMA_LOG_INFO(" dry_multiplier: %f\n", dry_multiplier); + LLAMA_LOG_INFO(" dry_allowed_length: %d\n", dry_allowed_length); + LLAMA_LOG_INFO(" max_exponent: %d\n", max_exponent); + LLAMA_LOG_INFO("DRY penalties ["); +#endif + + // Apply penalties for (const auto& pair : match_lengths) { auto next_token = pair.first; auto match_length = pair.second; - // if the match length is greater than or equal to our allowed length in config, we apply penalities - if (match_length >= (size_t)dry_allowed_length) { - - // find our next token in the candidates->data + if (match_length >= static_cast(dry_allowed_length)) { for (size_t i = 0; i < candidates->size; ++i) { if (candidates->data[i].id == next_token) { - // calculate the penalty - float penalty = dry_multiplier * pow(dry_base, match_length - dry_allowed_length); - - // apply the dry penalty + int repeat_exp = static_cast(match_length - dry_allowed_length); + if (max_exponent > 0 && repeat_exp > max_exponent) { + repeat_exp = max_exponent; + } + float penalty = dry_multiplier * pow(dry_base, static_cast(repeat_exp)); candidates->data[i].logit -= penalty; + +#ifdef DEBUG + LLAMA_LOG_INFO(" Token %d: %s (Penalty: %.2f)", next_token, detokenize_with_cache(next_token).c_str(), penalty); +#endif break; } } } } + +#ifdef DEBUG + LLAMA_LOG_INFO("]\n"); +#endif } void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) { diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 578c472438709..48cdc086c820c 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -28,7 +28,11 @@ void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_ void llama_sample_top_k_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep); void llama_sample_top_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); void llama_sample_min_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_dry_impl (llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size); +std::vector llama_tokenize(const struct llama_context * ctx, const std::string & text, bool add_special, bool parse_special); +std::vector llama_tokenize(const struct llama_model * model, const std::string & text, bool add_special, bool parse_special); +std::string llama_detokenize(llama_context * ctx, const std::vector & tokens, bool special); +std::string llama_detokenize_single(llama_context * ctx, llama_token token, bool special); +void llama_sample_dry_impl (struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector & dry_seq_breakers); void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep); void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); diff --git a/src/llama.cpp b/src/llama.cpp index d6d03fe0406b8..a8a97c0905dc4 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -18948,8 +18948,8 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep); } -void llama_sample_dry(llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const llama_token * dry_seq_breakers, size_t dry_seq_breakers_size) { - llama_sample_dry_impl(candidates, last_tokens, last_tokens_size, dry_base, dry_multiplier, dry_allowed_length, dry_seq_breakers, dry_seq_breakers_size); +void llama_sample_dry(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float dry_base, float dry_multiplier, int dry_allowed_length, const std::vector & dry_seq_breakers) { + llama_sample_dry_impl(ctx, candidates, last_tokens, last_tokens_size, dry_base, dry_multiplier, dry_allowed_length, dry_seq_breakers); } void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) {