Return assistant generated tokens mask in apply_chat_template #30650
Conversation
Force-pushed from c14a7f6 to 4f77aca.
My thoughts on your questions:
I prefer just returning the labels with masking applied, rather than returning the mask for the user to apply.
I think
Agreed! I guess the easiest way to check this is to just do a string search for the `{% generation %}` keyword in the template.
Yes, there's already a DataCollatorForCompletionOnlyLM which also requires this. I want to hear from @xenova and ideally someone using minijinja as well, though - how easily can we support this extension? Since it's only useful in training, maybe it's less critical to have it there.
I agree, but then what should be the ignore label? -100 (PyTorch)? I'm not sure it's a good idea to add another parameter.
I think
Yeah, I just thought of non-PyTorch users where -100 is not the default.
@yonigottesman Thanks for working on this, this is a feature I am very much looking forward to. Hope this can be merged soon.
Yes, sorry for not checking in @yonigottesman! Do you have other features you want to add, or should we treat this as review-ready?
@Rocketknight1 this is ready to be reviewed, yes :)
On it!
@yonigottesman while I'm reviewing, can you rebase/resolve the merge conflict? It's nothing major, but it'll block us merging the PR until it's ready. (Edit: Probably better to rebase because your branch is a little out of date by now; a rebase will catch any other issues before merging.)
Just did a review! Here are my comments:
- Overall looks good, and it's a really clever solution
- Performance is very good, the test runs in milliseconds on my machine
- Should we add support for this to one or more of the DataCollator classes?
- Some changes should be reverted, see the specific code comments
Finally, I'm not sure if we should be returning `labels` or an `assistant_mask` from `apply_chat_template()`. I think it makes sense to return masked labels from a data collator that supports this, but not from `apply_chat_template()` itself, because it's kind of weird for `apply_chat_template()` to be handling labels at all! I think it might be better if `apply_chat_template()` just returns a simpler mask, and then the data collators use that to do masking.
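For what it's worth, a minimal sketch of what such a collator could then do with the mask (the helper name `mask_to_labels` is illustrative, not an existing API):

```python
import torch

def mask_to_labels(input_ids: torch.Tensor, assistant_masks: torch.Tensor) -> torch.Tensor:
    # Copy the input ids and ignore every position outside the assistant turns,
    # so the loss is only computed on assistant-generated tokens.
    # -100 is the default ignore_index of torch.nn.CrossEntropyLoss.
    labels = input_ids.clone()
    labels[assistant_masks == 0] = -100
    return labels
```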
Force-pushed from 494c95e to d698ec7.
I agree it should be just the mask.
Yes, agree! It's also fine to leave that for a separate PR, and just add the mask functionality in this PR.
OK, fixed to now return the mask.
Force-pushed from a454b39 to 7bd0140.
Got it! Ping me whenever you're ready for re-review.
ready 😀
This looks really good now! I made a couple of small suggestions, but I think we're ready for core maintainer review now, cc @amyeroberts. There are also some failing tests, but these are unrelated, and should be fixed if you rebase.
Also, to make Amy's job easier, a quick explanation: This PR allows chat templates to mark assistant generations in the template. Very often, training pipelines only want to train on those tokens, and not compute loss on other tokens (e.g. control tokens, user messages, system messages).
The way it works is by adding a small Jinja extension to support `{% generation %}` blocks, and then combining the string offsets from these blocks with the string offsets from tokenization to create a mask array, which is included as one of the tokenization outputs alongside `input_ids` and `attention_mask`.
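Roughly, the offset-combination step can be pictured with this simplified sketch (the actual implementation in the PR may differ; the function name is illustrative, and the per-token offsets are the kind a fast tokenizer returns with `return_offsets_mapping=True`):

```python
from typing import List, Tuple

def char_spans_to_token_mask(
    generation_spans: List[Tuple[int, int]],  # (start_char, end_char) of each {% generation %} block
    offset_mapping: List[Tuple[int, int]],    # (start_char, end_char) of each token in the rendered chat
) -> List[int]:
    mask = [0] * len(offset_mapping)
    for span_start, span_end in generation_spans:
        for i, (tok_start, tok_end) in enumerate(offset_mapping):
            # Mark tokens that fall entirely inside an assistant-generated span
            if tok_start >= span_start and tok_end <= span_end:
                mask[i] = 1
    return mask
```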
```python
        return self._rendered_blocks or self._generation_indices

    @contextmanager
    def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[int]):
```
In mlflow/mlflow#12757, we found this line throws in Python 3.8.
```
class AssistantTracker(Extension):
    # This extension is used to track the indices of assistant-generated tokens in the rendered chat
    tags = {"generation"}

    def __init__(self, environment: ImmutableSandboxedEnvironment):
        # The class is only initiated by jinja.
        super().__init__(environment)
        environment.extend(activate_tracker=self.activate_tracker)
        self._rendered_blocks = None
        self._generation_indices = None

    def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
        lineno = next(parser.stream).lineno
        body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
        return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)

    @jinja2.pass_eval_context
    def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
        rv = caller()
        if self.is_active():
            # Only track generation indices if the tracker is active
            start_index = len("".join(self._rendered_blocks))
            end_index = start_index + len(rv)
            self._generation_indices.append((start_index, end_index))
        return rv

    def is_active(self) -> bool:
        return self._rendered_blocks or self._generation_indices

    @contextmanager
>   def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[int]):
E   TypeError: 'type' object is not subscriptable

__init__ = <function PreTrainedTokenizerBase._compile_jinja_template.<locals>.AssistantTracker.__init__ at 0x7f013dc78940>
__module__ = 'transformers.tokenization_utils_base'
__qualname__ = 'PreTrainedTokenizerBase._compile_jinja_template.<locals>.AssistantTracker'
_generation_support = <function PreTrainedTokenizerBase._compile_jinja_template.<locals>.AssistantTracker._generation_support at 0x7f013dc78790>
is_active = <function PreTrainedTokenizerBase._compile_jinja_template.<locals>.AssistantTracker.is_active at 0x7f013dc785e0>
parse = <function PreTrainedTokenizerBase._compile_jinja_template.<locals>.AssistantTracker.parse at 0x7f013dc78820>
tags = {'generation'}
```
```diff
- def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[int]):
+ def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]):
```
or `from __future__ import annotations` needs to be added.
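For example, a sketch of the second option:

```python
from __future__ import annotations  # must come before the other imports in the module

# With postponed evaluation of annotations, built-in generics such as list[int]
# are accepted in annotations even on Python 3.8:
def activate_tracker(rendered_blocks: list[int], generation_indices: list[int]) -> None:
    ...
```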
Thanks for flagging! Opening a PR to fix
https://github.com/huggingface/transformers/pull/32155/files
Amazing contribution! 🎉 🎉 🎉 It helps me a lot!
Some spelling mistakes in this example:
Thank you for your work on this! I'm having some issues though. When I run the example script from the tests, I don't seem to get any assistant tokens:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

dummy_template = (
    "{% for message in messages %}"
    "{% if (message['role'] != 'assistant') %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% elif (message['role'] == 'assistant')%}"
    "{{'<|im_start|>' + message['role'] + '\n'}}"
    "{% generation %}"
    "{{message['content'] + '<|im_end|>'}}"
    "{% endgeneration %}"
    "{{'\n'}}"
    "{% endif %}"
    "{% endfor %}"
)

conversations = [
    [
        {"role": "system", "content": "system message"},
        {"role": "user", "content": "user message"},
        {"role": "assistant", "content": "start turn 1 assistant message. end turn 1"},
        {"role": "user", "content": "user message 2"},
        {"role": "assistant", "content": "start turn 2 assistant message. end turn 2"},
    ],
    [
        {"role": "system", "content": "system message 3"},
        {"role": "user", "content": "user message 3"},
        {"role": "assistant", "content": "start turn 3 assistant message. end turn 3"},
        {"role": "user", "content": "user message 4"},
        {"role": "assistant", "content": "start turn 4 assistant message. end turn 4"},
    ],
]

output = tokenizer.apply_chat_template(
    conversations[0],
    chat_template=dummy_template,
    tokenize=True,
    return_assistant_tokens_mask=True,
    return_dict=True,
)

print("".join(map(str, output["assistant_masks"])))
```

For me, this prints a string of all 0s, i.e. no assistant tokens are marked. I think this bug is being caused by other tokens being rendered before the `{% generation %}` block in the same assistant turn. When I modify the template so the role header is inside the generation block:

```python
dummy_template = (
    "{% for message in messages %}"
    "{% if (message['role'] != 'assistant') %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% elif (message['role'] == 'assistant')%}"
    "{% generation %}"
    "{{'<|im_start|>' + message['role'] + '\n'}}"
    "{{message['content'] + '<|im_end|>'}}"
    "{% endgeneration %}"
    "{{'\n'}}"
    "{% endif %}"
    "{% endfor %}"
)
```

it works correctly, printing 1s over the assistant tokens. I am running the latest version of transformers, 4.44.0.
@avicooper1 there is a bug, but it's not about tokens before `{% generation %}` in the same turn. If you try a different tokenizer it will work.
Given the issue, is there a workaround to get the assistant mask?
Sadly, for Llama 3 I don't think so :(
Unfortunately, the fact that the template needs to contain the `{% generation %}` keyword makes this hard to use with existing models and templates.
@psinger there were several issues where some solutions were discussed, but nothing was flexible enough, as every model can have its own chat template and special tokens. See #28950 and #27609.
I can also confirm that there's something strange about Llama 3's tokenizer such that it just can't work, regardless of the template.
So in that case, should I manually add {% generation %} to each tokenizer's template?
Hi! I can confirm this works with the Mistral 7B tokenizer, but doesn't work with any of the Llama tokenizers (tried the 3 and 3.1 LlamaTokenizer).
Same here; tried the 3.2 model as well.
@thepowerfuldeez @kwanUm can you update the
Still doesn't work @yonigottesman. Apparently I need to add the {% generation %} tag manually? Is there existing functionality for it in HF or somewhere else? Here's a reproduction.
With default chat template:

```
$ pip list | grep -E "token|transfor|accele|sentence"
accelerate                1.0.1
```
until tokenizers start adding the `{% generation %}` keyword to their chat templates, you need to add it to the template manually.
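In the meantime, a tiny check along the lines of the string-search idea mentioned earlier in the thread (the helper name is illustrative, not an existing API):

```python
import re

GENERATION_TAG = re.compile(r"\{%-?\s*generation\s*-?%\}")

def template_has_generation_keyword(tokenizer) -> bool:
    # Heuristic: the assistant mask can only be non-trivial if the chat template
    # actually contains a {% generation %} block, so look for the opening tag
    # (allowing the {%- ... -%} whitespace-control variants).
    return bool(GENERATION_TAG.search(tokenizer.chat_template or ""))
```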
The following chat template is what I am using for the Llama 3.2 1B and 3B models. While there are a few differences between the Llama 3.1 and 3.2 chat templates, one can use this prompt as a reference point.

```python
LLAMA32_CHAT_TEMPLATE = """{{- bos_token }}
{%- if custom_tools is defined %}
{%- set tools = custom_tools %}
{%- endif %}
{%- if not tools_in_user_message is defined %}
{%- set tools_in_user_message = true %}
{%- endif %}
{%- if not date_string is defined %}
{%- if strftime_now is defined %}
{%- set date_string = strftime_now("%d %b %Y") %}
{%- else %}
{%- set date_string = "26 Jul 2024" %}
{%- endif %}
{%- endif %}
{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}
{#- This block extracts the system message, so we can slot it into the right place. #}
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content']|trim %}
{%- set messages = messages[1:] %}
{%- else %}
{%- set system_message = "" %}
{%- endif %}
{#- System message #}
{{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
{%- if tools is not none %}
{{- "Environment: ipython\n" }}
{%- endif %}
{{- "Cutting Knowledge Date: December 2023\n" }}
{{- "Today Date: " + date_string + "\n\n" }}
{%- if tools is not none and not tools_in_user_message %}
{{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}
{{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
{{- "Do not use variables.\n\n" }}
{%- for t in tools %}
{{- t | tojson(indent=4) }}
{{- "\n\n" }}
{%- endfor %}
{%- endif %}
{{- system_message }}
{{- "<|eot_id|>" }}
{#- Custom tools are passed in a user message with some extra guidance #}
{%- if tools_in_user_message and not tools is none %}
{#- Extract the first user message so we can plug it in here #}
{%- if messages | length != 0 %}
{%- set first_user_message = messages[0]['content']|trim %}
{%- set messages = messages[1:] %}
{%- else %}
{{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
{%- endif %}
{{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
{{- "Given the following functions, please respond with a JSON for a function call " }}
{{- "with its proper arguments that best answers the given prompt.\n\n" }}
{{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
{{- "Do not use variables.\n\n" }}
{%- for t in tools %}
{{- t | tojson(indent=4) }}
{{- "\n\n" }}
{%- endfor %}
{{- first_user_message + "<|eot_id|>"}}
{%- endif %}
{%- for message in messages %}
{%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
{%- if message.role != 'assistant' %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }}
{%- elif message.role == 'assistant' %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}}
{% generation %}
{{- message['content'] | trim + '<|eot_id|>' }}
{% endgeneration %}
{%- endif %}
{%- elif 'tool_calls' in message %}
{%- if not message.tool_calls|length == 1 %}
{{- raise_exception("This model only supports single tool-calls at once!") }}
{%- endif %}
{%- set tool_call = message.tool_calls[0].function %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
{{- '{"name": "' + tool_call.name + '", ' }}
{{- '"parameters": ' }}
{{- tool_call.arguments | tojson }}
{{- "}" }}
{{- "<|eot_id|>" }}
{%- elif message.role == "tool" or message.role == "ipython" %}
{{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
{%- if message.content is mapping or message.content is iterable %}
{{- message.content | tojson }}
{%- else %}
{{- message.content }}
{%- endif %}
{{- "<|eot_id|>" }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}"""
```
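For example, it could be plugged in roughly like this (the checkpoint name is only an example, and per the earlier comments a recent transformers release may be needed for Llama tokenizers to produce a non-empty mask):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
tokenizer.chat_template = LLAMA32_CHAT_TEMPLATE  # override the stock template

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi! How can I help?"},
]
out = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    return_dict=True,
    return_assistant_tokens_mask=True,
)
print(out["assistant_masks"])  # 1s over the assistant reply, 0s elsewhere
```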
What does this PR do?
This PR addresses issue #28950 and enhances the functionality of the `tokenizer.apply_chat_template` method when finetuning on chat datasets. The `tokenizer.apply_chat_template` method is recommended for maintaining consistency with the model's original template during both training and inference phases. This practice ensures that conversations are processed in a uniform manner.

Moreover, during the finetuning process on chat datasets, it is crucial to exclude tokens from the "user" or "system" segments of the conversation. This exclusion is necessary because including these tokens would train the model to predict not only the "assistant" responses but also potential user queries, which is undesirable (and strange).

Currently, the `tokenizer.apply_chat_template` method does not provide a way to identify which tokens belong to the "assistant" response. To address this, the PR introduces a new parameter called `return_assistant_mask`. This parameter returns a mask that identifies tokens generated by the assistant, allowing for the appropriate creation of a labels array with ignore (-100) values during training.

Additionally, this PR proposes the introduction of a new keyword, `generation` (name open for discussion), in the Jinja2 chat template. This keyword is used to encapsulate the assistant's response within your chat template.

Here is an example of the new API:
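A sketch of the intended usage, mirroring the dummy template from the tests earlier in the thread (the checkpoint is only an example and is one reported to work in this thread; note that the parameter name used in the tests and merged API is `return_assistant_tokens_mask`):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

# A template that wraps the assistant reply in {% generation %} ... {% endgeneration %}
chat_template = (
    "{% for message in messages %}"
    "{% if message['role'] != 'assistant' %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% else %}"
    "{{'<|im_start|>' + message['role'] + '\n'}}"
    "{% generation %}{{message['content'] + '<|im_end|>'}}{% endgeneration %}"
    "{{'\n'}}"
    "{% endif %}"
    "{% endfor %}"
)

messages = [
    {"role": "user", "content": "user message"},
    {"role": "assistant", "content": "assistant message"},
]
out = tokenizer.apply_chat_template(
    messages,
    chat_template=chat_template,
    tokenize=True,
    return_dict=True,
    return_assistant_tokens_mask=True,
)
# out["assistant_masks"] has 1s over the tokens inside the {% generation %} block
print(out["assistant_masks"])
```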
There are some issues I would like to discuss during this PR:

- Maybe this method should just return a `labels` key in the dict already and not bother with the intermediate mask.
- I called the new keyword `generation`, but maybe it should be `assistant_response`? Or anything you like.
- Maybe we should warn users who pass `return_assistant_mask` when the tokenizer chat template hasn't changed yet to support this new tag. That way users will know they are probably training on wrong tokens.
- Some training setups use `packing=True`. My new changes won't be easily usable if people use that parameter, and maybe we should think of my API while taking into consideration a refactor of the `packing` effect.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.