Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix flax whisper tokenizer bug #33151

Merged
merged 16 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/transformers/models/whisper/tokenization_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,8 @@ def _convert_to_list(token_ids):
token_ids = token_ids.cpu().numpy()
elif "tensorflow" in str(type(token_ids)):
token_ids = token_ids.numpy()
elif "jaxlib" in str(type(token_ids)):
token_ids = token_ids.tolist()
# now the token ids are either a numpy array, or a list of lists
if isinstance(token_ids, np.ndarray):
token_ids = token_ids.tolist()
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/whisper/tokenization_whisper_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,8 @@ def _convert_to_list(token_ids):
token_ids = token_ids.cpu().numpy()
elif "tensorflow" in str(type(token_ids)):
token_ids = token_ids.numpy()
elif "jaxlib" in str(type(token_ids)):
token_ids = token_ids.tolist()
# now the token ids are either a numpy array, or a list of lists
if isinstance(token_ids, np.ndarray):
token_ids = token_ids.tolist()
Expand Down
101 changes: 89 additions & 12 deletions tests/models/whisper/test_tokenization_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,16 @@
import numpy as np

from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
from transformers.testing_utils import slow
from transformers.models.whisper.tokenization_whisper import (
_combine_tokens_into_words,
_find_longest_common_sequence,
)
from transformers.testing_utils import (
require_flax,
require_tf,
require_torch,
slow,
)

from ...test_tokenization_common import TokenizerTesterMixin

Expand Down Expand Up @@ -113,7 +121,9 @@ def test_tokenizer_integration(self):
expected_encoding = {'input_ids': [[50257, 50362, 41762, 364, 357, 36234, 1900, 355, 12972, 13165, 354, 12, 35636, 364, 290, 12972, 13165, 354, 12, 5310, 13363, 12, 4835, 8, 3769, 2276, 12, 29983, 45619, 357, 13246, 51, 11, 402, 11571, 12, 17, 11, 5564, 13246, 38586, 11, 16276, 44, 11, 4307, 346, 33, 861, 11, 16276, 7934, 23029, 329, 12068, 15417, 28491, 357, 32572, 52, 8, 290, 12068, 15417, 16588, 357, 32572, 38, 8, 351, 625, 3933, 10, 2181, 13363, 4981, 287, 1802, 10, 8950, 290, 2769, 48817, 1799, 1022, 449, 897, 11, 9485, 15884, 354, 290, 309, 22854, 37535, 13, 50256], [50257, 50362, 13246, 51, 318, 3562, 284, 662, 12, 27432, 2769, 8406, 4154, 282, 24612, 422, 9642, 9608, 276, 2420, 416, 26913, 21143, 319, 1111, 1364, 290, 826, 4732, 287, 477, 11685, 13, 50256], [50257, 50362, 464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13, 50256]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]} # fmt: skip

self.tokenizer_integration_test_util(
expected_encoding=expected_encoding, model_name="openai/whisper-tiny.en", padding=False
expected_encoding=expected_encoding,
model_name="openai/whisper-tiny.en",
padding=False,
)

def test_output_offsets(self):
Expand All @@ -137,7 +147,10 @@ def test_output_offsets(self):
" small, sharp blow high on his chest.<|endoftext|>"
),
"offsets": [
{"text": " of spectators, retrievality is not worth thinking about.", "timestamp": (0.0, 5.0)},
{
"text": " of spectators, retrievality is not worth thinking about.",
"timestamp": (0.0, 5.0),
},
{
"text": " His instant panic was followed by a small, sharp blow high on his chest.",
"timestamp": (5.0, 9.4),
Expand Down Expand Up @@ -204,11 +217,21 @@ def test_skip_special_tokens_skips_prompt_ids(self):
# fmt: on
expected_with_special_tokens = "<|startofprev|> Mr. Quilter<|startoftranscript|><|en|><|transcribe|><|notimestamps|> On the general principles of art, Mr. Quilter writes with equal lucidity.<|endoftext|>"
expected_without_special_tokens = " On the general principles of art, Mr. Quilter writes with equal lucidity."
self.assertEqual(tokenizer.decode(encoded_input, skip_special_tokens=False), expected_with_special_tokens)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you remove all these changes which shouldn't be applied (our line length is 120 and this is a formatting change unrelated to the PR)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@amyeroberts All unrelated changes have been reverted. Now is it the proper time for merging the PR?

self.assertEqual(tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens)
self.assertEqual(rust_tokenizer.decode(encoded_input, skip_special_tokens=False), expected_with_special_tokens)
self.assertEqual(
rust_tokenizer.decode(encoded_input, skip_special_tokens=True), expected_without_special_tokens
tokenizer.decode(encoded_input, skip_special_tokens=False),
expected_with_special_tokens,
)
self.assertEqual(
tokenizer.decode(encoded_input, skip_special_tokens=True),
expected_without_special_tokens,
)
self.assertEqual(
rust_tokenizer.decode(encoded_input, skip_special_tokens=False),
expected_with_special_tokens,
)
self.assertEqual(
rust_tokenizer.decode(encoded_input, skip_special_tokens=True),
expected_without_special_tokens,
)

def test_skip_special_tokens_with_timestamps(self):
Expand Down Expand Up @@ -293,7 +316,13 @@ def test_combine_tokens_into_words(self):
# 'whatever "whatever" said someone, clever!?'
encoded_input = [1363, 7969, 503, 1363, 7969, 1, 848, 1580, 11, 13494, 7323]
expected_words = ["whatever", ' "whatever"', " said", " someone,", " clever!?"]
expected_tokens = [[1363, 7969], [503, 1363, 7969, 1], [848], [1580, 11], [13494, 7323]]
expected_tokens = [
[1363, 7969],
[503, 1363, 7969, 1],
[848],
[1580, 11],
[13494, 7323],
]
expected_indices = [[0, 1], [2, 3, 4, 5], [6], [7, 8], [9, 10]]
output = _combine_tokens_into_words(tokenizer, encoded_input)
self.assertEqual(expected_words, output[0])
Expand Down Expand Up @@ -321,7 +350,10 @@ def test_basic_normalizer(self):
self.assertEqual(decoded_output_normalize, expected_output_normalize)

decoded_output_diacritics = tokenizer.decode(
encoded_input, skip_special_tokens=True, basic_normalize=True, remove_diacritics=True
encoded_input,
skip_special_tokens=True,
basic_normalize=True,
remove_diacritics=True,
)
self.assertEqual(decoded_output_diacritics, expected_output_diacritics)

Expand All @@ -334,7 +366,10 @@ def test_basic_normalizer(self):
self.assertEqual(decoded_output_normalize, expected_output_normalize)

decoded_output_diacritics = rust_tokenizer.decode(
encoded_input, skip_special_tokens=True, basic_normalize=True, remove_diacritics=True
encoded_input,
skip_special_tokens=True,
basic_normalize=True,
remove_diacritics=True,
)
self.assertEqual(decoded_output_diacritics, expected_output_diacritics)

Expand All @@ -356,7 +391,10 @@ def test_decode_asr_with_word_level_timestamps(self):

tokenizer = WhisperTokenizer.from_pretrained("onnx-community/whisper-tiny.en_timestamped")
result = tokenizer._decode_asr(
model_outputs, return_timestamps="word", return_language=False, time_precision=0.02
model_outputs,
return_timestamps="word",
return_language=False,
time_precision=0.02,
)

EXPECTED_OUTPUT = (
Expand Down Expand Up @@ -574,3 +612,42 @@ def test_offset_decoding(self):

output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
self.assertEqual(output, [])

def test_convert_to_list_np(self):
test_list = [[1, 2, 3], [4, 5, 6]]

# Test with an already converted list
self.assertListEqual(WhisperTokenizer._convert_to_list(test_list), test_list)
self.assertListEqual(WhisperTokenizerFast._convert_to_list(test_list), test_list)

# Test with a numpy array
np_array = np.array(test_list)
self.assertListEqual(WhisperTokenizer._convert_to_list(np_array), test_list)
self.assertListEqual(WhisperTokenizerFast._convert_to_list(np_array), test_list)

@require_tf
def test_convert_to_list_tf(self):
import tensorflow as tf

test_list = [[1, 2, 3], [4, 5, 6]]
tf_tensor = tf.constant(test_list)
self.assertListEqual(WhisperTokenizer._convert_to_list(tf_tensor), test_list)
self.assertListEqual(WhisperTokenizerFast._convert_to_list(tf_tensor), test_list)

@require_flax
def test_convert_to_list_jax(self):
import jax.numpy as jnp

test_list = [[1, 2, 3], [4, 5, 6]]
jax_array = jnp.array(test_list)
self.assertListEqual(WhisperTokenizer._convert_to_list(jax_array), test_list)
self.assertListEqual(WhisperTokenizerFast._convert_to_list(jax_array), test_list)

@require_torch
def test_convert_to_list_pt(self):
import torch

test_list = [[1, 2, 3], [4, 5, 6]]
torch_tensor = torch.tensor(test_list)
self.assertListEqual(WhisperTokenizer._convert_to_list(torch_tensor), test_list)
self.assertListEqual(WhisperTokenizerFast._convert_to_list(torch_tensor), test_list)
Loading