llama : cache llama_token_to_piece #7587

Merged (4 commits, May 30, 2024)
Changes from 1 commit
161 changes: 92 additions & 69 deletions llama.cpp
@@ -1702,12 +1702,13 @@ struct llama_mlock {
};
using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;

static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
// NOTE: avoid ever using this except for building the token_to_piece caches
static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
std::vector<char> result(8, 0);
const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size(), special);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
int check = llama_token_to_piece(model, token, result.data(), result.size(), special);
GGML_ASSERT(check == -n_tokens);
}
else {
@@ -2162,7 +2163,11 @@ struct llama_vocab {
std::unordered_map<token, id> token_to_id;
std::vector<token_data> id_to_token;

std::vector<id> special_tokens_cache;
bool has_cache = false;
HanClinto (Collaborator) commented on May 29, 2024:

Is there a mechanism by which the vocab can be loaded without having a cache in place? If not, I'm wondering if has_cache is useful right now...?

Reply from the PR author (repository owner):

There was a way to exit early before creating the cache if the tokenizer was unknown. I've removed this path by throwing an exception: 1494a18

There is another path where the GGUF explicitly does not contain a vocabulary: "no_vocab". In that case, calling any of the functions that rely on a cache would throw an exception, since the caches are accessed via cache.at(). I think this makes sense.

Removed has_cache and replaced the unordered maps with vectors.
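
For readers following the thread: below is a minimal sketch of the vector-based cache shape described in this reply. The diff on this page still shows the earlier unordered_map / has_cache variant, so the member names and the at()-throwing behaviour here are assumptions based on the discussion, not the merged code.

    #include <cstdint>
    #include <string>
    #include <vector>

    // Sketch: one cached piece per token id, built once when the vocab is loaded.
    struct vocab_piece_caches {
        std::vector<std::string> cache_token_to_piece;         // rendered with special = false
        std::vector<std::string> cache_token_to_piece_special; // rendered with special = true

        const std::string & token_to_piece(int32_t id, bool special) const {
            const auto & cache = special ? cache_token_to_piece_special : cache_token_to_piece;
            // if the GGUF carried no vocabulary the caches are empty, so at()
            // throws std::out_of_range -- the behaviour described above
            return cache.at(id);
        }
    };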

Reply from the collaborator:

Perfect, thank you!

I like all your changes here -- this all feels really good. The only other thing I'll note is the caveat I raised on @ochafik 's similar PR, #6811:

#6811 (comment)

> I think it'd be simpler to leave it as is and keep it as an area where to potentially squeeze a couple of MB when times are scarce. wdyt?

This also sounds not unreasonable, but I don't know how to weigh such things. I really like grammar-constrained sampling, but I don't know how popular the feature is overall. Is it worth negatively impacting hyper-resource-constrained usages (such as Raspberry Pis) for the sake of grammars? That's what I'm unable to weigh -- it feels like a strategic decision that's a bit above my level.

In short, we don't need the cache for situations that don't use grammars, and we're adding a bit of memory usage (n_vocab*2) to every context that we're creating. On most systems this isn't a problem, but on highly constrained systems (such as a Raspberry Pi) this is wasted memory.

How do we weigh the interests of memory-constrained users vs. grammar-enabled users? That's not a call I'm able to make, but overall I think that improving the speed of grammar-enabled sampling will benefit the largest number of people, and the ultra-constrained users are a pretty small group. We might want to note somewhere in a comment that disabling the caching is one way to decrease memory usage, but beyond that we're probably fine with kicking that can down the road.
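
To put a rough number on the overhead being discussed, here is a hedged sketch (not code from this PR) that sums the bytes held by the two caches. The helper name is made up, and the llama 3 figures in the trailing comment assume the log quoted below counts only the string contents.

    #include <cstddef>
    #include <string>
    #include <vector>

    // Hypothetical helper: payload of both token-to-piece caches in MB
    // (ignores std::string bookkeeping and allocator overhead).
    static double token_to_piece_cache_mb(
            const std::vector<std::string> & cache_plain,
            const std::vector<std::string> & cache_special) {
        size_t bytes = 0;
        for (const auto & piece : cache_plain)   { bytes += piece.size(); }
        for (const auto & piece : cache_special) { bytes += piece.size(); }
        return bytes / (1024.0 * 1024.0);
    }

    // For llama 3 (n_vocab = 128256) the ~1.59 MB reported below works out to
    // roughly 6-7 bytes of piece text per token in each of the two caches.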

Reply from the PR author:

I added a log for the memory usage of the "token to piece" caches:

    # llama 3
    llm_load_vocab: token to piece cache size = 1.5928 MB

I think this is completely fine; no need to worry about it for now.

Reply from the collaborator:

Excellent, thank you! That was the one reservation that held me back from fully approving #6811 (I felt that choice required someone with a larger project scope than I have), so I'm very happy to have you weigh in on that.


std::vector<id> cache_special_tokens;
std::unordered_map<id, token> cache_token_to_piece; // llama_token_to_piece(special = false);
std::unordered_map<id, token> cache_token_to_piece_special; // llama_token_to_piece(special = true);

std::map<std::pair<std::string, std::string>, int> bpe_ranks;

@@ -4833,18 +4838,26 @@ static void llm_load_vocab(
{
for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
vocab.special_tokens_cache.push_back(id);
vocab.cache_special_tokens.push_back(id);
}
}

std::sort( vocab.special_tokens_cache.begin(), vocab.special_tokens_cache.end(),
std::sort( vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
[&] (const llama_vocab::id a, const llama_vocab::id b) {
return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
}
);

LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.special_tokens_cache.size());
LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.cache_special_tokens.size());
}

// build token to piece caches
for (llama_token id = 0; id < (llama_token) n_vocab; ++id) {
vocab.cache_token_to_piece[id] = llama_token_to_piece(&model, id, false);
vocab.cache_token_to_piece_special[id] = llama_token_to_piece(&model, id, true);
}

vocab.has_cache = true;
}

static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
@@ -13233,7 +13246,7 @@ struct fragment_buffer_variant {

static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
// for each special token
for (const llama_vocab::id special_id : vocab.special_tokens_cache) {
for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
const auto & special_token = vocab.id_to_token[special_id].text;

// for each text fragment
@@ -14392,7 +14405,7 @@ void llama_sample_repetition_penalties(

void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
GGML_ASSERT(ctx);
const int64_t t_start_sample_us = ggml_time_us();
int64_t t_start_sample_us = ggml_time_us();

bool allow_eog = false;
for (const auto & stack : grammar->stacks) {
@@ -14408,8 +14421,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
candidates_grammar.reserve(candidates->size);

for (size_t i = 0; i < candidates->size; ++i) {
const llama_token id = candidates->data[i].id;
const std::string piece = llama_token_to_piece(ctx, id, false);
const llama_token id = candidates->data[i].id;
const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(id);

if (llama_token_is_eog(&ctx->model, id)) {
if (!allow_eog) {
@@ -14609,7 +14622,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
GGML_ASSERT(false);
}

const std::string piece = llama_token_to_piece(ctx, token, false);
const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(token);

// Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar->partial_utf8);
@@ -18292,69 +18305,79 @@ static std::string llama_decode_text(const std::string & text) {

// does not write null-terminator to buf
int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) {
if (model->vocab.has_cache) {
const auto & cache = special ? model->vocab.cache_token_to_piece_special : model->vocab.cache_token_to_piece;
A collaborator commented on the lines above:

nit: Maybe we could get away w/ a single cache (built w/ special=true) and early-exit in special case at the top of the function?

int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) {
    if (!special && llama_is_control_token(model->vocab, token)) {
        return 0;
    }
    // if we have a cache - use it
    if (!model->vocab.cache_token_to_piece.empty()) {
         ....
    }
    ...

const auto & res = cache.at(token);
if (length < (int) res.size()) {
return -(int) res.size();
}
memcpy(buf, res.c_str(), res.size());
return res.size();
}

if (0 <= token && token < llama_n_vocab(model)) {
switch (llama_vocab_get_type(model->vocab)) {
case LLAMA_VOCAB_TYPE_WPM:
case LLAMA_VOCAB_TYPE_SPM: {
// NOTE: we accept all unsupported token types,
// suppressing them like CONTROL tokens.
if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].text;
llama_unescape_whitespace(result);
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (
(llama_is_user_defined_token(model->vocab, token)) ||
(llama_is_control_token (model->vocab, token) && special)) {
std::string result = model->vocab.id_to_token[token].text;
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
if (length < 3) {
return -3;
}
memcpy(buf, "\xe2\x96\x85", 3);
return 3;
} else if (llama_is_byte_token(model->vocab, token)) {
if (length < 1) {
return -1;
case LLAMA_VOCAB_TYPE_WPM:
case LLAMA_VOCAB_TYPE_SPM: {
// NOTE: we accept all unsupported token types,
// suppressing them like CONTROL tokens.
if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].text;
llama_unescape_whitespace(result);
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (
(llama_is_user_defined_token(model->vocab, token)) ||
(llama_is_control_token (model->vocab, token) && special)) {
std::string result = model->vocab.id_to_token[token].text;
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
if (length < 3) {
return -3;
}
memcpy(buf, "\xe2\x96\x85", 3);
return 3;
} else if (llama_is_byte_token(model->vocab, token)) {
if (length < 1) {
return -1;
}
buf[0] = llama_token_to_byte(model->vocab, token);
return 1;
}
buf[0] = llama_token_to_byte(model->vocab, token);
return 1;
break;
}
break;
}
case LLAMA_VOCAB_TYPE_BPE: {
// NOTE: we accept all unsupported token types,
// suppressing them like CONTROL tokens.
if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].text;
result = llama_decode_text(result);
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (
(llama_is_user_defined_token(model->vocab, token)) ||
(llama_is_control_token (model->vocab, token) && special)) {
std::string result = model->vocab.id_to_token[token].text;
if (length < (int) result.length()) {
return -(int) result.length();
case LLAMA_VOCAB_TYPE_BPE: {
// NOTE: we accept all unsupported token types,
// suppressing them like CONTROL tokens.
if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].text;
result = llama_decode_text(result);
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
} else if (
(llama_is_user_defined_token(model->vocab, token)) ||
(llama_is_control_token (model->vocab, token) && special)) {
std::string result = model->vocab.id_to_token[token].text;
if (length < (int) result.length()) {
return -(int) result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
}
memcpy(buf, result.c_str(), result.length());
return result.length();
break;
}
break;
}
default:
GGML_ASSERT(false);
default:
GGML_ASSERT(false);
}
}
return 0;
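As a usage note (illustrative caller-side code, not part of this PR): both the internal helper at the top of this diff and the new cached path keep llama_token_to_piece's convention of returning the negative of the required size when the buffer is too small and never writing a null terminator, so a typical caller does something like the sketch below.

    #include <string>
    #include <vector>

    #include "llama.h"

    // Grow the buffer when the API reports (as a negative value) that more
    // space is needed; the API does not write a null terminator.
    static std::string token_to_piece_str(const struct llama_model * model, llama_token token, bool special) {
        std::vector<char> buf(8, 0);
        int32_t n = llama_token_to_piece(model, token, buf.data(), (int32_t) buf.size(), special);
        if (n < 0) {
            buf.resize(-n);
            n = llama_token_to_piece(model, token, buf.data(), (int32_t) buf.size(), special);
        }
        return n > 0 ? std::string(buf.data(), n) : std::string();
    }
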
4 changes: 2 additions & 2 deletions llama.h
@@ -424,8 +424,8 @@ extern "C" {

LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);

LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);

LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);