diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 82dcba4fdb7b..b23e325e75e5 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -869,6 +869,8 @@ def _convert_to_list(token_ids): token_ids = token_ids.cpu().numpy() elif "tensorflow" in str(type(token_ids)): token_ids = token_ids.numpy() + elif "jaxlib" in str(type(token_ids)): + token_ids = token_ids.tolist() # now the token ids are either a numpy array, or a list of lists if isinstance(token_ids, np.ndarray): token_ids = token_ids.tolist() diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index 5019a9ebcda4..7227235b6406 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -602,6 +602,8 @@ def _convert_to_list(token_ids): token_ids = token_ids.cpu().numpy() elif "tensorflow" in str(type(token_ids)): token_ids = token_ids.numpy() + elif "jaxlib" in str(type(token_ids)): + token_ids = token_ids.tolist() # now the token ids are either a numpy array, or a list of lists if isinstance(token_ids, np.ndarray): token_ids = token_ids.tolist() diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py index 5c653f1984f6..27b24448d5a2 100644 --- a/tests/models/whisper/test_tokenization_whisper.py +++ b/tests/models/whisper/test_tokenization_whisper.py @@ -18,7 +18,7 @@ from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence -from transformers.testing_utils import slow +from transformers.testing_utils import require_flax, require_tf, require_torch, slow from ...test_tokenization_common import TokenizerTesterMixin @@ -574,3 +574,42 @@ def test_offset_decoding(self): output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"] self.assertEqual(output, []) + + def test_convert_to_list_np(self): + test_list = [[1, 2, 3], [4, 5, 6]] + + # Test with an already converted list + self.assertListEqual(WhisperTokenizer._convert_to_list(test_list), test_list) + self.assertListEqual(WhisperTokenizerFast._convert_to_list(test_list), test_list) + + # Test with a numpy array + np_array = np.array(test_list) + self.assertListEqual(WhisperTokenizer._convert_to_list(np_array), test_list) + self.assertListEqual(WhisperTokenizerFast._convert_to_list(np_array), test_list) + + @require_tf + def test_convert_to_list_tf(self): + import tensorflow as tf + + test_list = [[1, 2, 3], [4, 5, 6]] + tf_tensor = tf.constant(test_list) + self.assertListEqual(WhisperTokenizer._convert_to_list(tf_tensor), test_list) + self.assertListEqual(WhisperTokenizerFast._convert_to_list(tf_tensor), test_list) + + @require_flax + def test_convert_to_list_jax(self): + import jax.numpy as jnp + + test_list = [[1, 2, 3], [4, 5, 6]] + jax_array = jnp.array(test_list) + self.assertListEqual(WhisperTokenizer._convert_to_list(jax_array), test_list) + self.assertListEqual(WhisperTokenizerFast._convert_to_list(jax_array), test_list) + + @require_torch + def test_convert_to_list_pt(self): + import torch + + test_list = [[1, 2, 3], [4, 5, 6]] + torch_tensor = torch.tensor(test_list) + self.assertListEqual(WhisperTokenizer._convert_to_list(torch_tensor), test_list) + self.assertListEqual(WhisperTokenizerFast._convert_to_list(torch_tensor), test_list)