Skip to content

Commit

Permalink
llama : cache llama_token_to_piece
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov committed May 28, 2024
1 parent 0548a41 commit 92b88a0
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 69 deletions.
156 changes: 89 additions & 67 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1651,12 +1651,13 @@ struct llama_mlock {
};
using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;

static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
// NOTE: avoid ever using this except for building the token_to_piece caches
static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
std::vector<char> result(8, 0);
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size(), special);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
int check = llama_token_to_piece(model, token, result.data(), result.size(), special);
GGML_ASSERT(check == -n_tokens);
}
else {
Expand Down Expand Up @@ -2086,7 +2087,11 @@ struct llama_vocab {
std::unordered_map<token, id> token_to_id;
std::vector<token_data> id_to_token;

std::unordered_map<token, id> special_tokens_cache;
bool has_cache = false;

std::unordered_map<token, id> cache_special_tokens;
std::unordered_map<id, token> cache_token_to_piece; // llama_token_to_piece(special = false);
std::unordered_map<id, token> cache_token_to_piece_special; // llama_token_to_piece(special = true);

std::map<std::pair<std::string, std::string>, int> bpe_ranks;

Expand Down Expand Up @@ -4789,7 +4794,7 @@ static void llm_load_vocab(
// And skip the ones which are one character
if (utf8_str_len > 1) {
// At this point what we have left are special tokens only
vocab.special_tokens_cache[token] = id;
vocab.cache_special_tokens[token] = id;

// Count manually found special tokens
special_tokens_count_from_verification++;
Expand All @@ -4816,6 +4821,13 @@ static void llm_load_vocab(
);
}
}

for (llama_token id = 0; id < (llama_token) n_vocab; ++id) {
vocab.cache_token_to_piece[id] = llama_token_to_piece(&model, id, false);
vocab.cache_token_to_piece_special[id] = llama_token_to_piece(&model, id, true);
}

vocab.has_cache = true;
}

static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
Expand Down Expand Up @@ -12898,7 +12910,7 @@ struct fragment_buffer_variant {

static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
// for each special token
for (const auto & st: vocab.special_tokens_cache) {
for (const auto & st: vocab.cache_special_tokens) {
const auto & special_token = st.first;
const auto & special_id = st.second;

Expand Down Expand Up @@ -14058,7 +14070,7 @@ void llama_sample_repetition_penalties(

void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
GGML_ASSERT(ctx);
const int64_t t_start_sample_us = ggml_time_us();
int64_t t_start_sample_us = ggml_time_us();

bool allow_eog = false;
for (const auto & stack : grammar->stacks) {
Expand All @@ -14074,8 +14086,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
candidates_grammar.reserve(candidates->size);

for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
const std::string piece = llama_token_to_piece(ctx, id, false);
const llama_token id = candidates->data[i].id;
const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(id);

if (llama_token_is_eog(&ctx->model, id)) {
if (!allow_eog) {
Expand Down Expand Up @@ -14275,7 +14287,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
GGML_ASSERT(false);
}

const std::string piece = llama_token_to_piece(ctx, token, false);
const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(token);

// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
Expand Down Expand Up @@ -17948,69 +17960,79 @@ static std::string llama_decode_text(const std::string & text) {

// does not write null-terminator to buf
int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) {
if (model->vocab.has_cache) {
const auto & cache = special ? model->vocab.cache_token_to_piece_special : model->vocab.cache_token_to_piece;
const auto & res = cache.at(token);
if (length < (int) res.size()) {
return -(int) res.size();
}
memcpy(buf, res.c_str(), res.size());
return res.size();
}

if (0 <= token && token < llama_n_vocab(model)) {
switch (llama_vocab_get_type(model->vocab)) {
case LLAMA_VOCAB_TYPE_WPM:
case LLAMA_VOCAB_TYPE_SPM: {
// NOTE: we accept all unsupported token types,
// suppressing them like CONTROL tokens.
if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].text;
llama_unescape_whitespace(result);
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (
(llama_is_user_defined_token(model->vocab, token)) ||
(llama_is_control_token (model->vocab, token) && special)) {
std::string result = model->vocab.id_to_token[token].text;
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
if (length < 3) {
return -3;
}
memcpy(buf, "\xe2\x96\x85", 3);
return 3;
} else if (llama_is_byte_token(model->vocab, token)) {
if (length < 1) {
return -1;
case LLAMA_VOCAB_TYPE_WPM:
case LLAMA_VOCAB_TYPE_SPM: {
// NOTE: we accept all unsupported token types,
// suppressing them like CONTROL tokens.
if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].text;
llama_unescape_whitespace(result);
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (
(llama_is_user_defined_token(model->vocab, token)) ||
(llama_is_control_token (model->vocab, token) && special)) {
std::string result = model->vocab.id_to_token[token].text;
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
if (length < 3) {
return -3;
}
memcpy(buf, "\xe2\x96\x85", 3);
return 3;
} else if (llama_is_byte_token(model->vocab, token)) {
if (length < 1) {
return -1;
}
buf[0] = llama_token_to_byte(model->vocab, token);
return 1;
}
buf[0] = llama_token_to_byte(model->vocab, token);
return 1;
break;
}
break;
}
case LLAMA_VOCAB_TYPE_BPE: {
// NOTE: we accept all unsupported token types,
// suppressing them like CONTROL tokens.
if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].text;
result = llama_decode_text(result);
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (
(llama_is_user_defined_token(model->vocab, token)) ||
(llama_is_control_token (model->vocab, token) && special)) {
std::string result = model->vocab.id_to_token[token].text;
if (length < (int) result.length()) {
return -(int) result.length();
case LLAMA_VOCAB_TYPE_BPE: {
// NOTE: we accept all unsupported token types,
// suppressing them like CONTROL tokens.
if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].text;
result = llama_decode_text(result);
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (
(llama_is_user_defined_token(model->vocab, token)) ||
(llama_is_control_token (model->vocab, token) && special)) {
std::string result = model->vocab.id_to_token[token].text;
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
break;
}
break;
}
default:
GGML_ASSERT(false);
default:
GGML_ASSERT(false);
}
}
return 0;
Expand Down
4 changes: 2 additions & 2 deletions llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,8 @@ extern "C" {

LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);

LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);

LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
Expand Down

0 comments on commit 92b88a0

Please sign in to comment.