Bug: SPM tokenization breaks in at least one specific case. #7629

Closed
snichols opened this issue May 29, 2024 · 16 comments
Labels: bug-unconfirmed, medium severity, stale

Comments

@snichols (Contributor) commented May 29, 2024

What happened?

Consider this code snippet:

auto chat_ml_tokens = llama_tokenize(model, "<|im_start|><|im_end|>\n", false, true);
std::cout << "chat_ml_tokens found:";
for(const auto t : chat_ml_tokens) {
    std::cout << " " << t;
}

With the latest version this is generating the following output:

chat_ml_tokens found: 32001 32000 28705 13

In an earlier version of llama.cpp, the correct tokenization was generated:

chat_ml_tokens found: 32001 32000 13

This work is based on https://huggingface.co/cognitivecomputations/dolphin-2.8-mistral-7b-v02.

If I tokenize each component separately, I get the correct results for each token. However, tokenizing <|im_end|>\n results in an extra 28705 token in the output. Interestingly enough, <|im_start|>\n is also correct. There's something extra special about this <|im_end|>. I haven't methodically gone over previous commits to see when this problem was introduced. Let me know if that'll help narrow the cause down.

I'm pretty confident that I can work around this problem just by tokenizing each element separately. I'll do that and run the model through some tests. That being said, there may be some other tokenization issues in the code that are being surfaced by this.
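
A minimal sketch of that per-element workaround, assuming the same llama_tokenize helper from common.h used in the snippet above and a loaded llama_model pointer named model (names here are illustrative, not from the original report):

// Tokenize each ChatML piece on its own and concatenate the results,
// instead of passing the whole "<|im_start|><|im_end|>\n" string in one call.
std::vector<llama_token> tokens;
for (const char * piece : {"<|im_start|>", "<|im_end|>", "\n"}) {
    const auto part = llama_tokenize(model, piece, /*add_special=*/false, /*parse_special=*/true);
    tokens.insert(tokens.end(), part.begin(), part.end());
}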

Name and Version

This is a custom app using tag b3040 of llama.cpp. https://github.com/ggerganov/llama.cpp/releases/tag/b3040

What operating system are you seeing the problem on?

Linux

Relevant log output

No response

@snichols added the bug-unconfirmed and medium severity labels on May 29, 2024
@ggerganov (Owner) commented:
The two tokens <|im_start|> and <|im_end|> in this model have different values for the special property: one is true, the other is false:

https://huggingface.co/cognitivecomputations/dolphin-2.8-mistral-7b-v02/blob/6e0cd64b2c341a2328de00af2cd9d74cfcc89b74/tokenizer_config.json#L30-L45

This looks like a mistake in the model configuration to me.

@snichols (Contributor, Author) commented:
That is strange. I'll convert the model again with that fixed.

I'm not sure why that would cause this problem, though, because it's <|im_end|>\n that's exhibiting the bug, and that token is marked as special.

@jaime-m-p (Collaborator) commented May 30, 2024

I found more known problems. Tokenization also fails with "</s> a".
The rstrip property differs between the two models:

phi-3/tokenizer_config.json:

    "2": {
      "content": "</s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": true,  <--
      "single_word": false,
      "special": false
    },

dolphin-2.8-mistral-7b-v02/tokenizer_config.json:

    "2": {
      "content": "</s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,  <--
      "single_word": false,
      "special": true
    },

In phi-3, all added tokens have rstrip: true except UNK <unk>, BOS <s>, and EOS <|endoftext|>.
In dolphin-2.8-mistral-7b-v02, all added tokens have rstrip: false.
llama-spm also has rstrip set to false for all added tokens, but it seems to work.

I don't see any way to solve this correctly without having per-token flags/properties.

The only thing I can do is initialize rstrip to false (this should revert to the previous behavior), and then use a hardcoded check like if model_name == "phi-3": ... to fix other models.
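
For illustration only, a minimal sketch of what a per-token rstrip check could look like while splitting text on special tokens; the token_attribs struct and field names are hypothetical, not existing llama.cpp code:

#include <cctype>
#include <string>

// Hypothetical per-token attributes mirroring tokenizer_config.json.
struct token_attribs {
    bool lstrip = false;
    bool rstrip = false;
};

// After matching a special token that ends at position pos, consume the
// whitespace that immediately follows it when the token has rstrip set,
// so that whitespace is not re-tokenized as a separate fragment.
static size_t apply_rstrip(const std::string & text, size_t pos, const token_attribs & attribs) {
    if (attribs.rstrip) {
        while (pos < text.size() && std::isspace((unsigned char) text[pos])) {
            pos++;
        }
    }
    return pos;
}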

@snichols (Contributor, Author) commented Jun 1, 2024

It seems that if we want to handle arbitrary models, we need to level up the tokenizer to take all of that data into account. Hard-coding model names in the tokenizer might be a quick fix, but it's going to need ongoing maintenance. That's fragile and messy. Still, it's better than it just plain not working, I guess.

In my case, I'm skipping the llama_tokenize helper and going straight to the low-level tokenizer API to take control of the process.
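
For reference, a rough sketch of calling the low-level C API directly (this matches the llama.h signature from around that period; adjust for newer releases):

// Tokenize one piece with the raw llama.h API instead of the common.h wrapper.
const std::string piece = "<|im_end|>";
std::vector<llama_token> out(piece.size() + 8);
int32_t n = llama_tokenize(model, piece.c_str(), (int32_t) piece.size(),
                           out.data(), (int32_t) out.size(),
                           /*add_special=*/false, /*parse_special=*/true);
if (n < 0) {
    // A negative return value means the buffer was too small; -n is the required size.
    out.resize(-n);
    n = llama_tokenize(model, piece.c_str(), (int32_t) piece.size(),
                       out.data(), (int32_t) out.size(),
                       /*add_special=*/false, /*parse_special=*/true);
}
out.resize(n);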

@jaime-m-p (Collaborator) commented:
> Hard-coding model names in the tokenizer might be a quick fix, but it's going to need ongoing maintenance. That's fragile and messy. Still, it's better than it just plain not working, I guess.

I agree; meanwhile, see #7685.

@giladgd (Contributor) commented Jun 2, 2024

@snichols I think #7697 will solve the issue you encountered.
Can you please check that?

@jaime-m-p (Collaborator) commented:
@snichols Should work now (3b38d48).

@snichols (Contributor, Author) commented Jun 4, 2024

@jaime-m-p @giladgd I'm verifying now.

@snichols (Contributor, Author) commented Jun 4, 2024

The problem persists in 3b38d4. No change.

@jaime-m-p (Collaborator) commented:
> In an earlier version of llama.cpp, the correct tokenization was generated:
> chat_ml_tokens found: 32001 32000 13

@snichols Actually this is the unexpected output.

The correct output seems to have an additional space:

from transformers import AutoTokenizer

model = "./models/tokenizers/dolphin-2.8-mistral-7b-v02"
tokenizer = AutoTokenizer.from_pretrained(model)

text = "<|im_start|><|im_end|>\n"
ids  = tokenizer.encode(text, add_special_tokens=False)
re   = tokenizer.decode(ids)

print(repr(text))  # '<|im_start|><|im_end|>\n'
print(repr(re))    # '<|im_start|><|im_end|> \n'
print(ids)         # [32001, 32000, 28705, 13]

@snichols (Contributor, Author) commented Jun 5, 2024

@jaime-m-p Well, then all is as expected. Woot!

@github-actions github-actions bot added the stale label Jul 6, 2024
@shibe2 (Contributor) commented Jul 6, 2024

@jaime-m-p What is considered "correct" is not universal. For example, if I add the legacy=False parameter, I get a different result:

from transformers import AutoTokenizer

model = "cognitivecomputations/dolphin-2.8-mistral-7b-v02"
tokenizer = AutoTokenizer.from_pretrained(model, legacy=False)

text = "<|im_start|><|im_end|>\n"
ids  = tokenizer.encode(text, add_special_tokens=False)
re   = tokenizer.decode(ids)

print(repr(text))  # '<|im_start|><|im_end|>\n'
print(repr(re))    # '<|im_start|><|im_end|>\n'
print(ids)         # [32001, 32000, 13]

The situation with the handling of spaces in the SPM tokenizer is unfortunate, see #3664. My position, based on a practical approach, is that instead of trying to follow some standard here, it is better to give each model what it expects. If a particular model works better with a space after im_end, give it [32000, 28705, 13]. If it works better without the space, give it [32000, 13]. I see the following approaches to achieve that:

  • Make space insertion configurable as requested in #3664.
  • Record it in the model file, whether spaces should be inserted. tokenizer.ggml.add_space_prefix may be used for that. This would probably require specifying it manually during conversion, preferably with testing of both options to see which works better with the model.
  • Just don't insert spaces and have the client or user put them where needed. This is the one that I currently follow. While this is not an ideal solution, on the client side, it is much easier to add spaces where needed than to get rid of them after tokenization when they are unneeded.

@github-actions github-actions bot removed the stale label Jul 7, 2024
@jaime-m-p (Collaborator) commented:
@shibe2

> instead of trying to follow some standard here, it is better to give each model what it expects

I think I understand what you are stating here, but isn't it contradictory?
As I reinterpret it:
What does the model expect? --> training and fine-tuning parameters --> config JSON files.

> If a particular model works better with a space after im_end, give it [32000, 28705, 13]. If it works better without the space, give it [32000, 13].

Somehow this better alternative should be stored in the config JSON files.

> if I add the legacy=False parameter, I get a different result

Then we need to manage this config parameter too (it is currently not implemented in llama.cpp).

Also note, in this particular case for dolphin-2.8-mistral-7b-v02, that legacy=False is not what the model expects (see https://huggingface.co/cognitivecomputations/dolphin-2.8-mistral-7b-v02/blob/6e0cd64b2c341a2328de00af2cd9d74cfcc89b74/tokenizer_config.json#L51).

If this model was trained with legacy=True, can inference perform better with legacy=False?
If so, it is just a small change in the JSON file (though maybe not such a simple implementation in llama.cpp).

> I see the following approaches to achieve that:
> Record it in the model file, whether spaces should be inserted. tokenizer.ggml.add_space_prefix may be used for that. This would probably require specifying it manually during conversion, preferably with testing of both options to see which works better with the model.

I think this is the way, but you only need manual editing if you want to override the model's training config values.

Actually, we only have these tokenizer flags implemented:

llama.cpp/src/llama.cpp, lines 2608 to 2615 at f1948f1:

bool tokenizer_add_space_prefix = false;
bool tokenizer_add_bos = false;
bool tokenizer_add_eos = false;
bool tokenizer_ignore_merges = false;
bool tokenizer_clean_spaces = false; // clean_up_tokenization_spaces
bool tokenizer_remove_extra_whitespaces = false;
bool tokenizer_escape_whitespaces = true;
bool tokenizer_treat_whitespace_as_suffix = false;

Can you achieve what you expect using these flags, or by adding another config JSON flag?
Maybe we could implement a function like llama_config_tokenizer(...) to override these tokenizer flags.

If we add a non-standard flag, how can we test it for all models?
If the purpose is to detokenize so that the output matches the original input, testing is trivial.
I'm tempted to add a flag like detokenizer_replicate_original, but I'm not sure how much work a general implementation would require.

@shibe2 (Contributor) commented Jul 7, 2024

With regard to small details like this, a model's configuration files can contain errors and sub-optimal parameters. For any particular model, I'm not sure what exact tokenization quirks were in effect during training, so I would not be able to find an example of a model that performs better with different rules for inserting spaces between training and inference. But I can give a theoretical example where trying to match training parameters doesn't make sense: a merged model whose source models used different modes. I think such a model should work well both with and without extra spaces, but still, which mode is better can only be found out via testing. And here, I suspect, different quantization types may favor different options.

If we can store in the model file all the information needed to reproduce the desired variation of tokenization, then in some cases that information may need to be written after quantization is done. Given that whoever creates, converts, and uses a model may be three different entities, it's good to give the user an option to easily select which variation to use.

As to how to store that information, I’m not sure what would be the best way, but here is one:

  • if tokenizer.ggml.add_space_prefix is false, never insert spaces;
  • if tokenizer.ggml.add_space_prefix is true, insert spaces only for the very first token of the whole prompt and after BOS;
  • add spaces to chat template as needed.

This allows for flexibility like adding a space after im_start, but not after im_end. I don’t know if any model would benefit from it, but if one does, it will be covered.
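
A minimal sketch of that rule as a standalone predicate (this is a proposal sketch, not current llama.cpp behavior; the function and parameter names are made up for illustration):

// Decide whether the SPM space prefix should be inserted for a text fragment,
// following the rule proposed above: only when the flag is set, and only for
// the very first fragment of the prompt or a fragment right after BOS.
static bool should_add_space_prefix(bool add_space_prefix, bool is_first_fragment, bool follows_bos) {
    if (!add_space_prefix) {
        return false;  // never insert spaces
    }
    return is_first_fragment || follows_bos;
}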

@shibe2 (Contributor) commented Jul 7, 2024

As a side note, this issue concerns the case when parse_special=true, which is problematic to begin with: it converts text like "<|im_start|>" inside message content into special tokens, which should not be done. To prepare the prompt properly, text parts should be tokenized separately, so you should never need parse_special.
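
For example, a rough sketch of that approach, assuming the common.h llama_tokenize helper, a loaded model pointer, a user_text string, and the token ids quoted earlier for dolphin-2.8-mistral-7b-v02 (32001 = <|im_start|>, 32000 = <|im_end|>); other models use different ids:

// Build one ChatML turn without parse_special: control tokens are appended by id,
// and user-supplied text is tokenized as plain text, so a literal "<|im_start|>"
// typed inside the message content can never become a control token.
std::vector<llama_token> prompt;
prompt.push_back(32001);  // <|im_start|> (id assumed from the thread, model-specific)

const auto header = llama_tokenize(model, "user\n", /*add_special=*/false, /*parse_special=*/false);
prompt.insert(prompt.end(), header.begin(), header.end());

const auto content = llama_tokenize(model, user_text, /*add_special=*/false, /*parse_special=*/false);
prompt.insert(prompt.end(), content.begin(), content.end());

prompt.push_back(32000);  // <|im_end|> (id assumed from the thread, model-specific)

const auto newline = llama_tokenize(model, "\n", /*add_special=*/false, /*parse_special=*/false);
prompt.insert(prompt.end(), newline.begin(), newline.end());

Whether an extra space token appears before the trailing "\n" here still depends on the add_space_prefix handling discussed above, which is exactly the knob being debated.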

When #3664 is resolved, if ever, and a better way to handle special tokens in prompt formats/templates is implemented (this is planned), this will need to be considered again.

@github-actions github-actions bot added the stale label Aug 7, 2024
This issue was closed because it has been inactive for 14 days since being marked as stale.
