apply_chat_template: consistent behaviour for return_assistant_tokens_mask=True return_tensors=True (#35582)

* apply_chat_template: consistent return_tensors behaviour with return_assistant_tokens_mask flag

* test_chat_template_return_assistant_tokens_mask: support tokenizers with no attention mask

* test_chat_template_return_assistant_tokens_mask: skip tokenizers with no padding token

* test_chat_template_return_assistant_tokens_mask: force tokenizer padding_side=right

---------

Co-authored-by: Eduard Allakhverdov <goncharova@airi.net>
Co-authored-by: d.tarasov <d.tarasov@airi.net>
3 people authored Feb 4, 2025
1 parent 9c02cb6 commit 2ba040a
Showing 2 changed files with 61 additions and 1 deletion.
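
What the fix means in practice: with `return_assistant_tokens_mask=True` and `return_tensors="pt"`, the `assistant_masks` entry now comes back as a tensor shaped like `input_ids`, matching the rest of the `BatchEncoding`. A minimal usage sketch — the checkpoint name and the generation-tagged template below are illustrative stand-ins, not taken from this commit:

```python
from transformers import AutoTokenizer

# Illustrative checkpoint; any fast tokenizer with a chat template behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

# return_assistant_tokens_mask needs {% generation %} markers in the template;
# this toy template is an assumption for the example, not the tests' dummy_template.
template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'assistant' %}"
    "{% generation %}{{ message['content'] }}{% endgeneration %}"
    "{% else %}"
    "{{ message['content'] }}"
    "{% endif %}"
    "{% endfor %}"
)

conversation = [
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello, how can I help?"},
]

out = tokenizer.apply_chat_template(
    conversation,
    chat_template=template,
    tokenize=True,
    return_dict=True,
    return_assistant_tokens_mask=True,
    return_tensors="pt",
)

# After this commit the mask is a torch.Tensor with the same shape as input_ids;
# previously it stayed a plain (and unbatched) Python list.
print(type(out["assistant_masks"]), out["assistant_masks"].shape)
```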
src/transformers/tokenization_utils_base.py: 9 additions & 1 deletion
@@ -1742,7 +1742,15 @@ def apply_chat_template(
                         for token_id in range(start_token, end_token + 1 if end_token else len(input_ids[i])):
                             current_mask[token_id] = 1
                     assistant_masks.append(current_mask)
-            out["assistant_masks"] = assistant_masks if is_batched else assistant_masks[0]
+
+            if not is_batched and not return_tensors:
+                assistant_masks = assistant_masks[0]
+
+            out["assistant_masks"] = assistant_masks
+
+        if return_tensors:
+            out.convert_to_tensors(tensor_type=return_tensors)
+
             return out
         else:
             return out["input_ids"]
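
The ordering above is the whole fix: the mask is inserted into `out` first, kept batched whenever `return_tensors` is set, and only then is the entire `BatchEncoding` converted, so `assistant_masks` goes through `convert_to_tensors` together with `input_ids`. A standalone toy illustration of that conversion step (the literal values are made up):

```python
import torch
from transformers import BatchEncoding

# A batched encoding with the mask still stored as plain Python lists.
enc = BatchEncoding({
    "input_ids": [[1, 2, 3, 4]],
    "assistant_masks": [[0, 0, 1, 1]],
})

# convert_to_tensors walks every field, so the mask is converted alongside input_ids.
enc.convert_to_tensors(tensor_type="pt")
assert isinstance(enc["assistant_masks"], torch.Tensor)
print(enc["assistant_masks"].shape)  # torch.Size([1, 4])
```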
tests/test_tokenization_common.py: 52 additions & 0 deletions
@@ -62,6 +62,7 @@
 
 
 if is_torch_available():
+    import torch
     import torch.nn as nn
 
 
@@ -1219,6 +1220,7 @@ def test_jinja_strftime(self):
         self.assertEqual(len(strftime_output), 10)
         self.assertEqual(len(strftime_output.split("-")), 3)
 
+    @require_torch
     @require_jinja
     def test_chat_template_return_assistant_tokens_mask(self):
         dummy_template = (
@@ -1263,6 +1265,9 @@ def test_chat_template_return_assistant_tokens_mask(self):
                     self.skipTest(reason="No fast tokenizer defined")
 
                 tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name)
+                self._check_no_pad_token_padding(tokenizer_r, conversations)
+
+                tokenizer_r.padding_side = "right"
 
                 # check batched
                 output = tokenizer_r.apply_chat_template(
@@ -1272,6 +1277,20 @@ def test_chat_template_return_assistant_tokens_mask(self):
                     return_assistant_tokens_mask=True,
                     return_dict=True,
                 )
+
+                output_pt = tokenizer_r.apply_chat_template(
+                    conversations,
+                    chat_template=dummy_template,
+                    tokenize=True,
+                    padding=True,
+                    return_assistant_tokens_mask=True,
+                    return_dict=True,
+                    return_tensors="pt",
+                )
+
+                self.assertEqual(type(output_pt["assistant_masks"]), torch.Tensor)
+                self.assertEqual(output_pt["assistant_masks"].shape, output_pt["input_ids"].shape)
+
                 for i, conv in enumerate(conversations):
                     chat_string = tokenizer_r.apply_chat_template(
                         conversations[i], tokenize=False, chat_template=dummy_template
@@ -1297,18 +1316,30 @@ def test_chat_template_return_assistant_tokens_mask(self):
                         output["assistant_masks"][i][assistant_start : assistant_end + 1],
                         [1] * (assistant_end - assistant_start + 1),
                     )
+                    self.assertTrue(
+                        (output_pt["assistant_masks"][i, assistant_start : assistant_end + 1] == 1).all(),
+                    )
 
                     # assert 1 second assistant message
                     self.assertEqual(
                         output["assistant_masks"][i][assistant_start2 : assistant_end2 + 1],
                         [1] * (assistant_end2 - assistant_start2 + 1),
                     )
+                    self.assertTrue(
+                        (output_pt["assistant_masks"][i, assistant_start2 : assistant_end2 + 1] == 1).all(),
+                    )
+
                     # assert 0 in user/system indices
                     self.assertEqual(output["assistant_masks"][i][:assistant_start], [0] * assistant_start)
+                    self.assertTrue((output_pt["assistant_masks"][i, :assistant_start] == 0).all())
+
                     self.assertEqual(
                         output["assistant_masks"][i][assistant_end + 1 : assistant_start2],
                         [0] * (assistant_start2 - assistant_end - 1),
                     )
+                    self.assertTrue(
+                        (output_pt["assistant_masks"][i, assistant_end + 1 : assistant_start2] == 0).all(),
+                    )
 
                 # check not batched
                 output = tokenizer_r.apply_chat_template(
@@ -1318,6 +1349,17 @@ def test_chat_template_return_assistant_tokens_mask(self):
                     return_assistant_tokens_mask=True,
                     return_dict=True,
                 )
+                output_pt = tokenizer_r.apply_chat_template(
+                    conversations[0],
+                    chat_template=dummy_template,
+                    tokenize=True,
+                    return_assistant_tokens_mask=True,
+                    return_dict=True,
+                    return_tensors="pt",
+                )
+
+                self.assertEqual(type(output_pt["assistant_masks"]), torch.Tensor)
+                self.assertEqual(output_pt["assistant_masks"].shape, output_pt["input_ids"].shape)
 
                 chat_string = tokenizer_r.apply_chat_template(
                     conversations[0], tokenize=False, chat_template=dummy_template
@@ -1336,17 +1378,27 @@ def test_chat_template_return_assistant_tokens_mask(self):
                     output["assistant_masks"][assistant_start : assistant_end + 1],
                     [1] * (assistant_end - assistant_start + 1),
                 )
+                self.assertTrue(
+                    (output_pt["assistant_masks"][assistant_start : assistant_end + 1] == 1).all(),
+                )
                 self.assertEqual(
                     output["assistant_masks"][assistant_start2 : assistant_end2 + 1],
                     [1] * (assistant_end2 - assistant_start2 + 1),
                 )
+                self.assertTrue(
+                    (output_pt["assistant_masks"][assistant_start2 : assistant_end2 + 1] == 1).all(),
+                )
 
                 # assert 0 in user/system indices
                 self.assertEqual(output["assistant_masks"][:assistant_start], [0] * assistant_start)
+                self.assertTrue((output_pt["assistant_masks"][0, :assistant_start] == 0).all())
                 self.assertEqual(
                     output["assistant_masks"][assistant_end + 1 : assistant_start2],
                     [0] * (assistant_start2 - assistant_end - 1),
                 )
+                self.assertTrue(
+                    (output_pt["assistant_masks"][0, assistant_end + 1 : assistant_start2] == 0).all(),
+                )
 
     @require_jinja
     def test_chat_template_return_assistant_tokens_mask_truncated(self):
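
One downstream payoff of the tensor-typed mask is loss masking for fine-tuning: since `assistant_masks` now has the same shape as `input_ids`, it can index it directly. A hedged sketch with toy values — this usage is a common pattern, not part of the commit:

```python
import torch

# Toy stand-ins for the tensors apply_chat_template now returns.
input_ids = torch.tensor([[101, 7592, 2129, 102]])
assistant_masks = torch.tensor([[0, 0, 1, 1]])

# Supervise only assistant tokens: everywhere the mask is 0, write -100,
# the default ignore_index of torch.nn.CrossEntropyLoss.
labels = input_ids.clone()
labels[assistant_masks == 0] = -100
print(labels)  # tensor([[-100, -100, 2129,  102]])
```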
