-
Notifications
You must be signed in to change notification settings - Fork 27.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
🚨🚨[Whisper Tok] Update integration test #29368
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,7 @@ | |
|
||
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast | ||
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence | ||
from transformers.testing_utils import require_jinja, slow | ||
from transformers.testing_utils import slow | ||
|
||
from ...test_tokenization_common import TokenizerTesterMixin | ||
|
||
|
@@ -67,26 +67,26 @@ def test_full_tokenizer(self): | |
tokenizer = WhisperTokenizer.from_pretrained(self.tmpdirname) | ||
|
||
tokens = tokenizer.tokenize("This is a test") | ||
self.assertListEqual(tokens, ["This", "Ġis", "Ġa", "Ġ", "test"]) | ||
self.assertListEqual(tokens, ["This", "Ġis", "Ġa", "Ġtest"]) | ||
|
||
self.assertListEqual( | ||
tokenizer.convert_tokens_to_ids(tokens), | ||
[5723, 307, 257, 220, 31636], | ||
[5723, 307, 257, 1500], | ||
) | ||
|
||
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.") | ||
self.assertListEqual( | ||
tokens, | ||
["I", "Ġwas", "Ġborn", "Ġin", "Ġ9", "2000", ",", "Ġand", "Ġ", "this", "Ġis", "Ġfals", "é", "."], # fmt: skip | ||
) # fmt: skip | ||
["I", "Ġwas", "Ġborn", "Ġin", "Ġ9", "2000", ",", "Ġand", "Ġthis", "Ġis", "Ġfals", "é", "."], # fmt: skip | ||
) | ||
ids = tokenizer.convert_tokens_to_ids(tokens) | ||
self.assertListEqual(ids, [40, 390, 4232, 294, 1722, 25743, 11, 293, 220, 11176, 307, 16720, 526, 13]) | ||
self.assertListEqual(ids, [40, 390, 4232, 294, 1722, 25743, 11, 293, 341, 307, 16720, 526, 13]) | ||
|
||
back_tokens = tokenizer.convert_ids_to_tokens(ids) | ||
self.assertListEqual( | ||
back_tokens, | ||
["I", "Ġwas", "Ġborn", "Ġin", "Ġ9", "2000", ",", "Ġand", "Ġ", "this", "Ġis", "Ġfals", "é", "."], # fmt: skip | ||
) # fmt: skip | ||
["I", "Ġwas", "Ġborn", "Ġin", "Ġ9", "2000", ",", "Ġand", "Ġthis", "Ġis", "Ġfals", "é", "."], # fmt: skip | ||
) | ||
|
||
def test_tokenizer_slow_store_full_signature(self): | ||
pass | ||
|
@@ -499,25 +499,3 @@ def test_offset_decoding(self): | |
|
||
output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"] | ||
self.assertEqual(output, []) | ||
|
||
@require_jinja | ||
def test_tokenization_for_chat(self): | ||
Reviewer comment: Chat template doesn't make sense for Whisper (a speech recognition model) — have removed the test to keep the CI lightweight (cc @Rocketknight1)
Reply: Fine with me!
||
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny") | ||
# This is in English, but it's just here to make sure the chat control tokens are being added properly | ||
test_chats = [ | ||
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}], | ||
[ | ||
{"role": "system", "content": "You are a helpful chatbot."}, | ||
{"role": "user", "content": "Hello!"}, | ||
{"role": "assistant", "content": "Nice to meet you."}, | ||
], | ||
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}], | ||
] | ||
tokenized_chats = [multilingual_tokenizer.apply_chat_template(test_chat) for test_chat in test_chats] | ||
expected_tokens = [ | ||
[3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257], | ||
[3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257, 37717, 220, 1353, 1677, 291, 13, 50257], | ||
[37717, 220, 1353, 1677, 291, 13, 50257, 15947, 0, 50257], | ||
] | ||
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens): | ||
self.assertListEqual(tokenized_chat, expected_tokens) |
There was a problem hiding this comment. Choose a reason for hiding it; the reason will be displayed to describe this comment to others. Learn more.
This now gives equivalent results to the original:
Print Output: