Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stop the generation when <|eom_id|> token is encountered (needed for llama 3.1 tool call support) #8858

Merged
merged 5 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ class Tokenizer:
SUFFIX_ID = "tokenizer.ggml.suffix_token_id"
MIDDLE_ID = "tokenizer.ggml.middle_token_id"
EOT_ID = "tokenizer.ggml.eot_token_id"
EOM_ID = "tokenizer.ggml.eom_token_id"

class Adapter:
TYPE = "adapter.type"
Expand Down Expand Up @@ -1327,3 +1328,4 @@ def get_type(val: Any) -> GGUFValueType:
KEY_TOKENIZER_SUFFIX_ID = Keys.Tokenizer.SUFFIX_ID
KEY_TOKENIZER_MIDDLE_ID = Keys.Tokenizer.MIDDLE_ID
KEY_TOKENIZER_EOT_ID = Keys.Tokenizer.EOT_ID
KEY_TOKENIZER_EOM_ID = Keys.Tokenizer.EOM_ID
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,9 @@ def add_middle_token_id(self, id: int) -> None:
def add_eot_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.EOT_ID, id)

def add_eom_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.EOM_ID, id)

def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
pack_prefix = ''
if not skip_pack_prefix:
Expand Down
7 changes: 6 additions & 1 deletion src/llama-vocab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1444,7 +1444,8 @@ llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, lla
bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
return token != -1 && (
token == llama_token_eos_impl(vocab) ||
token == llama_token_eot_impl(vocab)
token == llama_token_eot_impl(vocab) ||
token == llama_token_eom_impl(vocab)
);
}

Expand Down Expand Up @@ -1500,6 +1501,10 @@ llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
return vocab.special_eot_id;
}

llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
return vocab.special_eom_id;
}

int32_t llama_tokenize_impl(
const struct llama_vocab & vocab,
const char * text,
Expand Down
2 changes: 2 additions & 0 deletions src/llama-vocab.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ struct llama_vocab {
id special_suffix_id = -1;
id special_middle_id = -1;
id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
id special_eom_id = -1;

// tokenizer flags
bool tokenizer_add_space_prefix = false;
Expand Down Expand Up @@ -101,6 +102,7 @@ llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
llama_token llama_token_eot_impl (const struct llama_vocab & vocab);
llama_token llama_token_eom_impl (const struct llama_vocab & vocab);

int32_t llama_tokenize_impl(
const struct llama_vocab & vocab,
Expand Down
14 changes: 14 additions & 0 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ enum llm_kv {
LLM_KV_TOKENIZER_SUFFIX_ID,
LLM_KV_TOKENIZER_MIDDLE_ID,
LLM_KV_TOKENIZER_EOT_ID,
LLM_KV_TOKENIZER_EOM_ID,

LLM_KV_ADAPTER_TYPE,
LLM_KV_ADAPTER_LORA_ALPHA,
Expand Down Expand Up @@ -459,6 +460,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" },
{ LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" },
{ LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" },
{ LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" },

{ LLM_KV_ADAPTER_TYPE, "adapter.type" },
{ LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" },
Expand Down Expand Up @@ -5586,6 +5588,7 @@ static void llm_load_vocab(
{ LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id },
{ LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id },
{ LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id },
{ LLM_KV_TOKENIZER_EOM_ID, vocab.special_eom_id },
};

for (const auto & it : special_token_types) {
Expand Down Expand Up @@ -5638,6 +5641,17 @@ static void llm_load_vocab(
}
}
}

// find EOM token: "<|eom_id|>"
//
// TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOM_ID
// for now, we apply this workaround to find the EOM token based on its text
if (vocab.special_eom_id == -1) {
const auto & t = vocab.token_to_id.find("<|eom_id|>");
if (t != vocab.token_to_id.end()) {
vocab.special_eom_id = t->second;
}
}
}

// build special tokens cache
Expand Down
Loading