
Return assistant generated tokens mask in apply_chat_template #30650

Merged

Conversation

yonigottesman
Contributor

What does this PR do?

This PR addresses issue #28950 and enhances the functionality of the tokenizer.apply_chat_template method when finetuning on chat datasets.

The method tokenizer.apply_chat_template is recommended for maintaining consistency with the model's original template during both training and inference phases. This practice ensures that conversations are processed in a uniform manner.

Moreover, during the finetuning process on chat datasets, it is crucial to exclude tokens from the "user" or "system" segments of the conversation. This exclusion is necessary because including these tokens would train the model to predict not only the "assistant" responses but also potential user queries, which is undesirable (and strange).

Currently, the tokenizer.apply_chat_template method does not provide a way to identify which tokens belong to the "assistant" response. To address this, the PR introduces a new parameter called return_assistant_mask. This parameter returns a mask that identifies tokens generated by the assistant, allowing for the appropriate creation of a labels array with ignore (-100) values during training.

Additionally, this PR proposes the introduction of a new keyword generation (name open for discussion) in the jinja2 chat template. This keyword is used to encapsulate the assistant’s response within your chat template.

Here is an example of the new API:

template = (
    "{% for message in messages %}"
    "{% if (message['role'] != 'assistant') %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% elif (message['role'] == 'assistant')%}"
    "{{'<|im_start|>' + message['role'] + '\n'}}"
    "{% generation %}"
    "{{message['content'] + '<|im_end|>'}}"
    "{% endgeneration %}"
    "{{'\n'}}"
    "{% endif %}"
    "{% endfor %}"
)
dummy_conversation = [
      {"role": "system", "content": "system message"},
      {"role": "user", "content": "user message"},
      {"role": "assistant", "content": "assistant\nmessage"},
      {"role": "user", "content": "user message 2"},
      {"role": "assistant", "content": "assistant message 2"},
]

output = tokenizer_r.apply_chat_template(
    dummy_conversations,
    chat_template=dummy_template,
    tokenize=True,
    return_assistant_mask=True,
    return_dict=True,
  )

labels = [output["input_ids"][index] if mask == 1 else -100 for index, mask in enumerate(output["assistant_mask"])]

There are some issues I would like to discuss during this PR:

  • Is this API fine? Maybe we should just return a labels key in the dict and not bother with the intermediate mask.
  • Name of the new tag? Currently generation, but maybe it should be assistant_response? Or anything you like.
  • I think maybe I should add a warning if a user runs with return_assistant_mask but the tokenizer chat template hasn't been updated yet to support this new tag. That way users will know they are probably training on the wrong tokens.
  • In 99% of finetuning examples I see people using the TRL trainer with packing=True. My new changes won't be easily usable if people use that parameter, and maybe we should think about my API while taking a refactor of the packing behavior into consideration.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@yonigottesman yonigottesman marked this pull request as draft May 4, 2024 09:53
@yonigottesman yonigottesman marked this pull request as ready for review May 4, 2024 09:53
@yonigottesman yonigottesman force-pushed the apply-chat-template-assistant-mask branch 3 times, most recently from c14a7f6 to 4f77aca on May 5, 2024 08:44
@LysandreJik LysandreJik requested a review from Rocketknight1 May 6, 2024 09:40
@Rocketknight1
Member

cc @lewtun and @xenova to this as well!

@Rocketknight1
Member

My thoughts on your questions:

Is this API fine? Maybe we should just return a labels key in the dict and not bother with the intermediate mask.

I prefer just returning the labels with masking applied, rather than returning the mask for the user to apply.

Name of the new tag? Currently generation, but maybe it should be assistant_response? Or anything you like.

I think generation is fine - assistant_response is very long!

I think maybe I should add a warning if a user runs with return_assistant_mask but the tokenizer chat template hasn't been updated yet to support this new tag. That way users will know they are probably training on the wrong tokens.

Agreed! I guess the easiest way to check this is to just do a string search for {% generation %} tags? Be careful, because you'll also need to check for whitespace-control variants like {%- generation %}.
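A minimal sketch of such a check (illustrative, not the actual implementation), using a regex that also matches the whitespace-control variants:

import re

# Matches {% generation %} as well as whitespace-control variants
# such as {%- generation %} and {% generation -%}.
_GENERATION_TAG = re.compile(r"\{%-?\s*generation\s*-?%\}")

def template_has_generation_tag(chat_template: str) -> bool:
    return bool(_GENERATION_TAG.search(chat_template))

print(template_has_generation_tag("{%- generation %}{{ message['content'] }}{% endgeneration %}"))  # True
print(template_has_generation_tag("{{ message['content'] }}"))  # False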

In 99% of finetuning examples I see people using the TRL trainer with packing=True. My new changes won't be easily usable if people use that parameter, and maybe we should think about my API while taking a refactor of the packing behavior into consideration.

Yes, there's already a DataCollatorForCompletionOnlyLM which also requires packing=False. I feel like we can slot in with that easily enough!

I want to hear from @xenova and ideally someone using minijinja as well, though - how easily can we support this extension? Since it's only useful in training, maybe it's less critical to have it in huggingface/jinja or TGI, but at the very least we should be able to gracefully ignore the generation tags.

@yonigottesman
Contributor Author

I prefer just returning the labels with masking applied, rather than returning the mask for the user to apply.

I agree, but then what should the ignore label be? -100 (PyTorch)? I'm not sure it's a good idea to add another parameter ignore_label.

@Rocketknight1
Member

I think -100 is correct, yes! This is the standard value for Torch and Transformers, so we don't need an extra arg to change it.
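For reference, -100 is the default ignore_index of PyTorch's cross-entropy loss, so positions labeled -100 simply contribute nothing to the loss; a minimal illustration:

import torch
from torch import nn

loss_fn = nn.CrossEntropyLoss()            # ignore_index=-100 by default
logits = torch.randn(4, 10)                # 4 positions, vocabulary of 10
labels = torch.tensor([3, -100, 7, -100])
print(loss_fn(logits, labels))             # loss is computed on positions 0 and 2 only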

@yonigottesman
Contributor Author

Yeah, I was just thinking of non-PyTorch users, for whom -100 is not the default.
Anyway, I updated the code to return labels.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@psinger

psinger commented Jun 3, 2024

@yonigottesman Thanks for working on this, this is a feature I am very much looking forward to. Hope this can be merged soon.

@Rocketknight1
Member

Yes, sorry for not checking in @yonigottesman! Do you have other features you want to add, or should we treat this as review-ready?

@yonigottesman
Contributor Author

@Rocketknight1 this is ready to be reviewed yes :)

@Rocketknight1
Member

On it!

@Rocketknight1
Member

Rocketknight1 commented Jun 7, 2024

@yonigottesman while I'm reviewing can you rebase/resolve the merge conflict? It's nothing major, but it'll block us merging the PR until it's ready. (Edit: Probably better to rebase because your branch is a little out of date by now, a rebase will catch any other issues before merging)

Member

@Rocketknight1 Rocketknight1 left a comment


Just did a review! Here are my comments:

  • Overall looks good, and it's a really clever solution
  • Performance is very good, the test runs in milliseconds on my machine
  • Should we add support for this to one or more of the DataCollator classes?
  • Some changes should be reverted, see the specific code comments

Finally, I'm not sure if we should be returning labels or an assistant_mask from apply_chat_template(). I think it makes sense to return masked labels from a data collator that supports this, but not from apply_chat_template() itself, because it's kind of weird for apply_chat_template() to be handling labels at all! I think it might be better if apply_chat_template() just returns a simpler mask, and then the data collators use that to do masking.
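A rough sketch of what such a collator could look like (hypothetical, not an existing transformers class), assuming each feature carries the assistant_masks key eventually returned by apply_chat_template:

import torch

def collate_with_assistant_mask(features, pad_token_id):
    # Hypothetical collator: pads input_ids and builds labels that are -100
    # everywhere except on assistant tokens (mask == 1).
    max_len = max(len(f["input_ids"]) for f in features)
    input_ids, attention_mask, labels = [], [], []
    for f in features:
        ids, mask = f["input_ids"], f["assistant_masks"]
        pad = max_len - len(ids)
        input_ids.append(ids + [pad_token_id] * pad)
        attention_mask.append([1] * len(ids) + [0] * pad)
        labels.append([tok if m == 1 else -100 for tok, m in zip(ids, mask)] + [-100] * pad)
    return {
        "input_ids": torch.tensor(input_ids),
        "attention_mask": torch.tensor(attention_mask),
        "labels": torch.tensor(labels),
    }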

@yonigottesman yonigottesman force-pushed the apply-chat-template-assistant-mask branch from 494c95e to d698ec7 on June 10, 2024 09:06
@yonigottesman
Contributor Author

I agree it should be assistant_mask and not labels. I feel like the collator should be added here and not in TRL, what do you think?

@Rocketknight1
Member

Yes, agree! It's also fine to leave that for a separate PR, and just add the mask functionality in this PR.

@yonigottesman
Contributor Author

OK, fixed to now return the mask.

@yonigottesman yonigottesman force-pushed the apply-chat-template-assistant-mask branch 2 times, most recently from a454b39 to 7bd0140 on June 14, 2024 06:04
@Rocketknight1
Member

Got it! Ping me whenever you're ready for re-review.

@yonigottesman
Contributor Author

ready 😀

Member

@Rocketknight1 Rocketknight1 left a comment


This looks really good now! I made a couple of small suggestions, but I think we're ready for core maintainer review now, cc @amyeroberts. There are also some failing tests, but these are unrelated, and should be fixed if you rebase.

Also, to make Amy's job easier, a quick explanation: This PR allows chat templates to mark assistant generations in the template. Very often, training pipelines only want to train on those tokens, and not compute loss on other tokens (e.g. control tokens, user messages, system messages).

The way it works is by adding a small Jinja extension to support {% generation %} blocks, and then combining the string offsets from these blocks with the string offsets from tokenization to create a mask array, which is included as one of the tokenization outputs alongside input_ids and attention_mask.
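A simplified sketch of that combining step (names are illustrative, not the exact implementation), assuming a fast-tokenizer encoding of the rendered string and a list of (start_char, end_char) spans collected by the extension:

def build_assistant_mask(encoding, generation_indices):
    # encoding: BatchEncoding from a fast tokenizer for the rendered chat string
    # generation_indices: (start_char, end_char) spans from {% generation %} blocks
    mask = [0] * len(encoding["input_ids"])
    for start_char, end_char in generation_indices:
        for char_idx in range(start_char, end_char):
            token_idx = encoding.char_to_token(char_idx)
            if token_idx is not None:  # some characters map to no token
                mask[token_idx] = 1
    return mask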

@yonigottesman yonigottesman deleted the apply-chat-template-assistant-mask branch July 22, 2024 17:25
return self._rendered_blocks or self._generation_indices

@contextmanager
def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[int]):
Contributor

@harupy harupy Jul 23, 2024


@amyeroberts @yonigottesman

In mlflow/mlflow#12757, we found this line throws in Python 3.8.

https://github.com/mlflow/mlflow/actions/runs/10056412801/job/27795200814?pr=12757#step:12:1016

    class AssistantTracker(Extension):
        # This extension is used to track the indices of assistant-generated tokens in the rendered chat
        tags = {"generation"}
    
        def __init__(self, environment: ImmutableSandboxedEnvironment):
            # The class is only initiated by jinja.
            super().__init__(environment)
            environment.extend(activate_tracker=self.activate_tracker)
            self._rendered_blocks = None
            self._generation_indices = None
    
        def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
            lineno = next(parser.stream).lineno
            body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
            return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)
    
        @jinja2.pass_eval_context
        def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
            rv = caller()
            if self.is_active():
                # Only track generation indices if the tracker is active
                start_index = len("".join(self._rendered_blocks))
                end_index = start_index + len(rv)
                self._generation_indices.append((start_index, end_index))
            return rv
    
        def is_active(self) -> bool:
            return self._rendered_blocks or self._generation_indices
    
        @contextmanager
>       def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[int]):
E       TypeError: 'type' object is not subscriptable

__init__   = <function PreTrainedTokenizerBase._compile_jinja_template.<locals>.AssistantTracker.__init__ at 0x7f013dc78940>
__module__ = 'transformers.tokenization_utils_base'
__qualname__ = 'PreTrainedTokenizerBase._compile_jinja_template.<locals>.AssistantTracker'
_generation_support = <function PreTrainedTokenizerBase._compile_jinja_template.<locals>.AssistantTracker._generation_support at 0x7f013dc78790>
is_active  = <function PreTrainedTokenizerBase._compile_jinja_template.<locals>.AssistantTracker.is_active at 0x7f013dc785e0>
parse      = <function PreTrainedTokenizerBase._compile_jinja_template.<locals>.AssistantTracker.parse at 0x7f013dc78820>
tags       = {'generation'}


Suggested change
def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[int]):
def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]):

or from __future__ import annotations needs to be added.
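For context, a minimal illustration of why the annotation fails on Python 3.8 and how the __future__ import avoids it (names here are illustrative):

# On Python 3.8, list[int] in an annotation is evaluated at definition time and
# raises TypeError: 'type' object is not subscriptable. Deferring annotation
# evaluation makes the same annotation safe.
from __future__ import annotations

def activate_tracker(rendered_blocks: list[str], generation_indices: list[tuple[int, int]]) -> None:
    ...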


@stceum

stceum commented Aug 7, 2024

Amazing contribution! 🎉 🎉 🎉 It helps me a lot!

Here is an example of the new API:

template = (
    "{% for message in messages %}"
    "{% if (message['role'] != 'assistant') %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% elif (message['role'] == 'assistant')%}"
    "{{'<|im_start|>' + message['role'] + '\n'}}"
    "{% generation %}"
    "{{message['content'] + '<|im_end|>'}}"
    "{% endgeneration %}"
    "{{'\n'}}"
    "{% endif %}"
    "{% endfor %}"
)
dummy_conversation = [
      {"role": "system", "content": "system message"},
      {"role": "user", "content": "user message"},
      {"role": "assistant", "content": "assistant\nmessage"},
      {"role": "user", "content": "user message 2"},
      {"role": "assistant", "content": "assistant message 2"},
]

output = tokenizer_r.apply_chat_template(
    dummy_conversations,
    chat_template=dummy_template,
    tokenize=True,
    return_assistant_mask=True,
    return_dict=True,
  )

labels = [output["input_ids"][index] if mask == 1 else -100 for index, mask in enumerate(output["assistant_mask"])]

Some spelling mistakes in this example:

output = tokenizer_r.apply_chat_template(
    dummy_conversations,
    chat_template=dummy_template,
    tokenize=True,
    return_assistant_mask=True,
    return_dict=True,
)

dummy_conversation instead of dummy_conversations, template instead of dummy_template

labels = [output["input_ids"][index] if mask == 1 else -100 for index, mask in enumerate(output["assistant_mask"])]

assistant_masks instead of assistant_mask
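Putting those fixes together (and using the return_assistant_tokens_mask argument name from the merged API, as in the later examples below), the snippet becomes:

output = tokenizer_r.apply_chat_template(
    dummy_conversation,
    chat_template=template,
    tokenize=True,
    return_assistant_tokens_mask=True,
    return_dict=True,
)

labels = [output["input_ids"][index] if mask == 1 else -100 for index, mask in enumerate(output["assistant_masks"])]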

@avicooper1

Thank you for your work on this!

I'm having some issues though. When I run the example script from the tests, I don't seem to get any assistant tokens:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

dummy_template = (
    "{% for message in messages %}"
    "{% if (message['role'] != 'assistant') %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% elif (message['role'] == 'assistant')%}"
    "{{'<|im_start|>' + message['role'] + '\n'}}"
    "{% generation %}"
    "{{message['content'] + '<|im_end|>'}}"
    "{% endgeneration %}"
    "{{'\n'}}"
    "{% endif %}"
    "{% endfor %}"
)
conversations = [
    [
        {"role": "system", "content": "system message"},
        {"role": "user", "content": "user message"},
        {"role": "assistant", "content": "start turn 1 assistant message. end turn 1"},
        {"role": "user", "content": "user message 2"},
        {"role": "assistant", "content": "start turn 2 assistant message. end turn 2"},
    ],
    [
        {"role": "system", "content": "system message 3"},
        {"role": "user", "content": "user message 3"},
        {"role": "assistant", "content": "start turn 3 assistant message. end turn 3"},
        {"role": "user", "content": "user message 4"},
        {"role": "assistant", "content": "start turn 4 assistant message. end turn 4"},
    ],
]

output = tokenizer.apply_chat_template(
    conversations[0],
    chat_template=dummy_template,
    tokenize=True,
    return_assistant_tokens_mask=True,
    return_dict=True,
)
print("".join(map(str, output["assistant_masks"])))

For me, this prints out 0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000

I think this bug is being caused by other tokens being printed out before the {% generation %}, within the same turn. For example, if I change the chat template to:

dummy_template = (
    "{% for message in messages %}"
    "{% if (message['role'] != 'assistant') %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% elif (message['role'] == 'assistant')%}"
    "{% generation %}"
    "{{'<|im_start|>' + message['role'] + '\n'}}"
    "{{message['content'] + '<|im_end|>'}}"
    "{% endgeneration %}"
    "{{'\n'}}"
    "{% endif %}"
    "{% endfor %}"
)

it works correctly, printing out 0000000000000000000000000000000011111111111111111111111110000000000000000001111111111111111111111111
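As a quick sanity check (not part of the library), you can decode only the masked positions from the output above to see exactly which tokens would be trained on:

assistant_token_ids = [tok_id for tok_id, m in zip(output["input_ids"], output["assistant_masks"]) if m == 1]
print(tokenizer.decode(assistant_token_ids))
# With the modified template this prints the assistant turns (plus the wrapped
# role header) and nothing from the user or system messages.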

I am running the latest version of transformers, 4.44.0

@yonigottesman
Contributor Author

@avicooper1 there is a bug, but it's not about tokens before "generation" in the same turn. If you try a different tokenizer it will work.
There is something strange about the Llama 3 tokenizer (PreTrainedTokenizerFast): for some reason the char_to_token function isn't working as expected, and my implementation is based on its result.
I opened an issue huggingface/tokenizers#1620.
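A minimal way to observe the behavior being relied on (illustrative only; the model names are ones mentioned in this thread, and results depend on the tokenizer config) is to map the character span of an assistant message back to token indices with char_to_token and compare tokenizers:

from transformers import AutoTokenizer

text = "<|im_start|>assistant\nhello there<|im_end|>\n"
span_start = text.index("hello")
span_end = text.index("<|im_end|>") + len("<|im_end|>")

for name in ["mistralai/Mistral-7B-Instruct-v0.2", "meta-llama/Meta-Llama-3.1-8B-Instruct"]:
    tok = AutoTokenizer.from_pretrained(name)
    enc = tok(text, add_special_tokens=False)
    print(name, enc.char_to_token(span_start), enc.char_to_token(span_end - 1))
# If char_to_token returns None for characters inside the span, the mask built
# from those offsets stays all zeros, which matches the behavior described above.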

@Boltzmachine

Given the issue, is there a workaround to get the assistant mask?

@yonigottesman
Contributor Author

Sadly, for Llama 3 I don't think so :(
Other models that use the same tokenizer class, PreTrainedTokenizerFast (but a different config), do work, for example tiiuae/falcon-mamba-7b-instruct. So I guess it's something specific to the Llama 3 configuration.

@psinger

psinger commented Aug 12, 2024

Unfortunately, the fact that the template needs to contain the {% generation %} part makes it very inflexible to use. Would it somehow be possible to just generate the mask based on the provided user/assistant inputs?

@yonigottesman
Contributor Author

@psinger there were several issues where some solutions were discussed, but nothing was flexible enough, as every model can have its own chat template and special tokens. See #28950 and #27609.
Having a {% generation %} keyword was the only thing I could come up with. If you have a better idea, that would be great and I'll be happy to try to implement it.

@zjysteven

I can also confirm that there's something strange about llama3's tokenizer such that it just can't work regardless of the template.

@Boltzmachine

So in that case I should manually add {% generation %} to each tokenizer's template?

@thepowerfuldeez

Hi! I can confirm this works with the Mistral 7B tokenizer, but it doesn't work with any of the Llama tokenizers (tried the 3 and 3.1 LlamaTokenizer).

@kwanUm

kwanUm commented Oct 30, 2024

Same here, tried 3.2 model as well.

@yonigottesman
Contributor Author

@thepowerfuldeez @kwanUm can you update the tokenizers package to the latest version and check? There was a fix, huggingface/tokenizers#1640, that should have fixed this issue.

@kwanUm

kwanUm commented Oct 31, 2024

Still doesn't work @yonigottesman; apparently I need to add the {% generation %} tag manually? Is there existing functionality for it in HF or somewhere else?

Here's a reproduction

from transformers import AutoTokenizer

# Load the tokenizer for LLaMA 3.1
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

# Define the conversation
conversation = [
    {"role": "system", "content": "blablabla"},
    {"role": "user", "content": "blablabla"},
    {"role": "assistant", "content": "blablabla"},
]



# Apply the chat template with tokenization and assistant mask
output = tokenizer.apply_chat_template(
    conversation,
    tokenize=True,
    return_assistant_tokens_mask=True,
    return_dict=True,
)

# Print the assistant_mask
print("With default chat template:")
print(output['assistant_masks'])

# Define the chat template with the {% generation %} tag
chat_template = (
    "{% for message in messages %}"
    "{% if message['role'] != 'assistant' %}"
    "{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ '<|im_start|>' + message['role'] + '\n' }}"
    "{% generation %}"
    "{{ message['content'] + '<|im_end|>\n' }}"
    "{% endgeneration %}"
    "{% endif %}"
    "{% endfor %}"
)

# Apply the chat template with tokenization and assistant mask
output = tokenizer.apply_chat_template(
    conversation,
    chat_template=chat_template,
    tokenize=True,
    return_assistant_tokens_mask=True,
    return_dict=True,
)


# Print the assistant_mask
print("With custom chat template containing the {% generation %} tag:")
print(output['assistant_masks'])

Output:

With default chat template:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
With custom chat template containing the {% generation %} tag:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]



pip list | grep -E "token|transfor|accele|sentence"

accelerate 1.0.1
asttokens 2.4.1
tiktoken 0.8.0
tokenizers 0.20.1
transformers 4.46.0

@yonigottesman
Contributor Author

Until tokenizers start shipping the generation tag in the chat template in tokenizer_config.json, you will have to add it manually.
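For example, one way to do that manually (a sketch; LLAMA32_CHAT_TEMPLATE is assumed to be a hand-edited copy of the model's chat template containing {% generation %} blocks, such as the Llama 3.2 template shared in the next comment) is to assign it back to the tokenizer and save it:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

# Hand-edited template with {% generation %} ... {% endgeneration %} blocks (see next comment).
tokenizer.chat_template = LLAMA32_CHAT_TEMPLATE
tokenizer.save_pretrained("llama32-with-generation-tag")  # persists it in tokenizer_config.json

output = tokenizer.apply_chat_template(
    [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}],
    tokenize=True,
    return_assistant_tokens_mask=True,
    return_dict=True,
)
print(output["assistant_masks"])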

@ospanbatyr

The following chat template is what I am using for Llama 3.2 1B and 3B models. While there are a few differences between Llama 3.1 and 3.2 chat templates, one can use this prompt as a reference point.

LLAMA32_CHAT_TEMPLATE = """{{- bos_token }}
{%- if custom_tools is defined %}
    {%- set tools = custom_tools %}
{%- endif %}
{%- if not tools_in_user_message is defined %}
    {%- set tools_in_user_message = true %}
{%- endif %}
{%- if not date_string is defined %}
    {%- if strftime_now is defined %}
        {%- set date_string = strftime_now("%d %b %Y") %}
    {%- else %}
        {%- set date_string = "26 Jul 2024" %}
    {%- endif %}
{%- endif %}
{%- if not tools is defined %}
    {%- set tools = none %}
{%- endif %}

{#- This block extracts the system message, so we can slot it into the right place. #}
{%- if messages[0]['role'] == 'system' %}
    {%- set system_message = messages[0]['content']|trim %}
    {%- set messages = messages[1:] %}
{%- else %}
    {%- set system_message = "" %}
{%- endif %}

{#- System message #}
{{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
{%- if tools is not none %}
    {{- "Environment: ipython\n" }}
{%- endif %}
{{- "Cutting Knowledge Date: December 2023\n" }}
{{- "Today Date: " + date_string + "\n\n" }}
{%- if tools is not none and not tools_in_user_message %}
    {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}
    {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
    {{- "Do not use variables.\n\n" }}
    {%- for t in tools %}
        {{- t | tojson(indent=4) }}
        {{- "\n\n" }}
    {%- endfor %}
{%- endif %}
{{- system_message }}
{{- "<|eot_id|>" }}

{#- Custom tools are passed in a user message with some extra guidance #}
{%- if tools_in_user_message and not tools is none %}
    {#- Extract the first user message so we can plug it in here #}
    {%- if messages | length != 0 %}
        {%- set first_user_message = messages[0]['content']|trim %}
        {%- set messages = messages[1:] %}
    {%- else %}
        {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
{%- endif %}
    {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
    {{- "Given the following functions, please respond with a JSON for a function call " }}
    {{- "with its proper arguments that best answers the given prompt.\n\n" }}
    {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
    {{- "Do not use variables.\n\n" }}
    {%- for t in tools %}
        {{- t | tojson(indent=4) }}
        {{- "\n\n" }}
    {%- endfor %}
    {{- first_user_message + "<|eot_id|>"}}
{%- endif %}

{%- for message in messages %}
    {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
            {%- if message.role  != 'assistant' %}
                  {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }}
            {%- elif message.role  == 'assistant' %}
                  {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'}}
                 {% generation %}
                 {{- message['content'] | trim + '<|eot_id|>' }}
                 {% endgeneration %}
           {%- endif %}
    {%- elif 'tool_calls' in message %}
        {%- if not message.tool_calls|length == 1 %}
            {{- raise_exception("This model only supports single tool-calls at once!") }}
        {%- endif %}
        {%- set tool_call = message.tool_calls[0].function %}
        {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
        {{- '{"name": "' + tool_call.name + '", ' }}
        {{- '"parameters": ' }}
        {{- tool_call.arguments | tojson }}
        {{- "}" }}
        {{- "<|eot_id|>" }}
    {%- elif message.role == "tool" or message.role == "ipython" %}
        {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
        {%- if message.content is mapping or message.content is iterable %}
            {{- message.content | tojson }}
        {%- else %}
            {{- message.content }}
        {%- endif %}
        {{- "<|eot_id|>" }}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}"""
