Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🚨🚨[Whisper Tok] Update integration test #29368

Merged
merged 2 commits into from
Mar 1, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 8 additions & 30 deletions tests/models/whisper/test_tokenization_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
from transformers.testing_utils import require_jinja, slow
from transformers.testing_utils import slow

from ...test_tokenization_common import TokenizerTesterMixin

Expand Down Expand Up @@ -67,26 +67,26 @@ def test_full_tokenizer(self):
tokenizer = WhisperTokenizer.from_pretrained(self.tmpdirname)

tokens = tokenizer.tokenize("This is a test")
self.assertListEqual(tokens, ["This", "Ġis", "Ġa", "Ġ", "test"])
self.assertListEqual(tokens, ["This", "Ġis", "Ġa", "Ġtest"])

self.assertListEqual(
tokenizer.convert_tokens_to_ids(tokens),
[5723, 307, 257, 220, 31636],
[5723, 307, 257, 1500],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This now gives equivalent results to the original:

from whisper.tokenizer import get_tokenizer

tokenizer = get_tokenizer(True)
tokens = tokenizer.encode("This is a test")
print(tokens)

Print Output:

[5723, 307, 257, 1500]

)

tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
self.assertListEqual(
tokens,
["I", "Ġwas", "Ġborn", "Ġin", "Ġ9", "2000", ",", "Ġand", "Ġ", "this", "Ġis", "Ġfals", "é", "."], # fmt: skip
) # fmt: skip
["I", "Ġwas", "Ġborn", "Ġin", "Ġ9", "2000", ",", "Ġand", "Ġthis", "Ġis", "Ġfals", "é", "."], # fmt: skip
)
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(ids, [40, 390, 4232, 294, 1722, 25743, 11, 293, 220, 11176, 307, 16720, 526, 13])
self.assertListEqual(ids, [40, 390, 4232, 294, 1722, 25743, 11, 293, 341, 307, 16720, 526, 13])

back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(
back_tokens,
["I", "Ġwas", "Ġborn", "Ġin", "Ġ9", "2000", ",", "Ġand", "Ġ", "this", "Ġis", "Ġfals", "é", "."], # fmt: skip
) # fmt: skip
["I", "Ġwas", "Ġborn", "Ġin", "Ġ9", "2000", ",", "Ġand", "Ġthis", "Ġis", "Ġfals", "é", "."], # fmt: skip
)

def test_tokenizer_slow_store_full_signature(self):
pass
Expand Down Expand Up @@ -499,25 +499,3 @@ def test_offset_decoding(self):

output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
self.assertEqual(output, [])

@require_jinja
def test_tokenization_for_chat(self):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chat template doesn't make sense for Whisper (a speech recognition model) - have removed the test to keep the CI lightweight (cc @Rocketknight1)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine with me!

multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
# This is in English, but it's just here to make sure the chat control tokens are being added properly
test_chats = [
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
[
{"role": "system", "content": "You are a helpful chatbot."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Nice to meet you."},
],
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
]
tokenized_chats = [multilingual_tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
expected_tokens = [
[3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257],
[3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257, 37717, 220, 1353, 1677, 291, 13, 50257],
[37717, 220, 1353, 1677, 291, 13, 50257, 15947, 0, 50257],
]
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
self.assertListEqual(tokenized_chat, expected_tokens)
Loading