
fix issue with logit processor during beam search in Flax #29636

Merged Mar 21, 2024 (1 commit)

Conversation

giganttheo
Contributor

What does this PR do?

Fixes #29635

@gante
Member

gante commented Mar 18, 2024

Hi @giganttheo 👋 Thank you for opening the PR!

The fix looks reasonable to me. However, if the fix is indeed correct, I wonder how our code could be running correctly before 🤔

I have three small requests:

  1. Can you share your jax and flax versions?
  2. Can you share a small reproducer that was broken before this PR, but fixed after it?
  3. Can you run the following tests locally and confirm that they pass after these changes: [these use beam search + flax]
    a. RUN_SLOW=1 py.test tests/models/bart/test_modeling_flax_bart.py -vv
    b. RUN_SLOW=1 py.test tests/models/t5/test_modeling_flax_t5.py -vv
    c. RUN_SLOW=1 py.test tests/models/whisper/test_modeling_flax_whisper.py -vv

@giganttheo
Contributor Author

giganttheo commented Mar 18, 2024

I think that not many people have encountered this issue before, since most Flax logits processors do not actually use the input_ids argument, with the exception of FlaxWhisperTimeStampLogitsProcessor.
I only found out about it while working on a no-repeat n-gram logits processor for Flax (cf. #29677), which does use this argument.
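For context, a processor of this kind only works if input_ids really holds the running beam sequences. Here is a minimal NumPy sketch of 2-gram blocking (illustrative only, not the transformers implementation; the actual Flax processors are called with (input_ids, scores, cur_len) and must be jit-traceable, which this plain-Python loop is not):

```python
import numpy as np

def block_repeated_bigrams(input_ids, scores):
    """Set logits to -inf for tokens that would repeat an existing bigram.

    input_ids: (batch, cur_len) array of tokens generated so far
    scores:    (batch, vocab_size) array of next-token logits
    """
    scores = scores.copy()
    for b, seq in enumerate(input_ids.tolist()):
        prev = seq[-1]
        # every token that already followed `prev` earlier in this sequence is banned
        banned = {seq[i + 1] for i in range(len(seq) - 1) if seq[i] == prev}
        for tok in banned:
            scores[b, tok] = -np.inf
    return scores

# after generating [5, 7, 5], token 7 must be banned (bigram (5, 7) already occurred)
out = block_repeated_bigrams(np.array([[5, 7, 5]]), np.zeros((1, 10)))
print(out[0, 7])  # -inf
```

If input_ids holds the wrong sequences (the bug this PR fixes), the banned set is computed from the wrong history and the blocking silently fails.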

About your requests:

  1. I am working with jax==0.4.13 and flax==0.7.2
  2. Since most code paths don't actually use the input_ids argument, here is a snippet that uses the FlaxNoRepeatNGramLogitsProcessor from Adding FlaxNoRepeatNGramLogitsProcessor #29677 :
from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small", dtype="bfloat16")

input_text = "translate English to French: hello how are you? hello how are you? hello how are you? hello how are you? hello how are you?"

input_ids = tokenizer(input_text, return_tensors="np").input_ids

decoder_start_token_id = model.config.decoder_start_token_id

outputs = model.generate(input_ids=input_ids, num_beams=2, decoder_start_token_id=decoder_start_token_id, no_repeat_ngram_size=2)
outputs.sequences, tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)

without the change, it gives:

(Array([[    0, 21845,     6,  1670,     3,  6738,    18,  3249,    58,
         21845,     6,  1670,     3,  6738,    18,  3249,    58,     1,
             0,     0]], dtype=int32),
 ['Bonjour, comment êtes-vous? Bonjour, comment êtes-vous?'])

For instance, the 2-gram (1670, 3) is repeated, which should not be possible with n-gram blocking.

with the change, it gives:

(Array([[    0, 21845,     6,  1670,     3,  6738,    18,  3249,    58,
         21845,     3,    15,    17,  1670,    58,     1,     0,     0,
             0,     0]], dtype=int32),
 ['Bonjour, comment êtes-vous? Bonjour et comment?'])

There is no 2-gram repetition.
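The claim can be checked mechanically from the token ids shown above (a small standalone check; token id 1 is T5's EOS, so the padding after it is ignored):

```python
def has_repeated_ngram(seq, n=2, eos=1):
    """Return True if any n-gram occurs more than once in seq.

    Padding after the first EOS token is trimmed before checking.
    """
    if eos in seq:
        seq = seq[: seq.index(eos) + 1]  # drop padding after the first EOS
    ngrams = [tuple(seq[i : i + n]) for i in range(len(seq) - n + 1)]
    return len(ngrams) != len(set(ngrams))

before_fix = [0, 21845, 6, 1670, 3, 6738, 18, 3249, 58,
              21845, 6, 1670, 3, 6738, 18, 3249, 58, 1, 0, 0]
after_fix = [0, 21845, 6, 1670, 3, 6738, 18, 3249, 58,
             21845, 3, 15, 17, 1670, 58, 1, 0, 0, 0, 0]

print(has_repeated_ngram(before_fix))  # True: (1670, 3) among others repeats
print(has_repeated_ngram(after_fix))   # False: no repeated 2-gram
```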

For reference, with torch and 2-gram blocking, the model output is:

tensor([    0, 21845,     6,  1670,     3,  6738,    18,  3249,    58, 21845,
            3,    15,    17,  1670,   327,    58,     1])
Bonjour, comment êtes-vous? Bonjour et comment vous?
  3. Most tests pass; for the ones that fail or are skipped, I am not sure whether that comes from the library versions, from deprecated code, or from this modification. Here are the logs of the local test runs, in case they are helpful:

RUN_SLOW=1 pytest tests/models/bart/test_modeling_flax_bart.py -vv gives 52 passed:

============================= test session starts ==============================
platform linux -- Python 3.9.18, pytest-8.1.1, pluggy-1.4.0 -- /home/gigant/miniconda3/envs/transformers-dev/bin/python
cachedir: .pytest_cache
rootdir: /home/gigant/Documents/transformers_fix/transformers
configfile: pyproject.toml
collected 52 items                                                             

tests/models/bart/test_modeling_flax_bart.py::BartHeadTests::test_lm_forward PASSED [  1%]
tests/models/bart/test_modeling_flax_bart.py::BartHeadTests::test_lm_uneven_forward PASSED [  3%]
tests/models/bart/test_modeling_flax_bart.py::BartHeadTests::test_question_answering_forward PASSED [  5%]
tests/models/bart/test_modeling_flax_bart.py::BartHeadTests::test_sequence_classification_forward PASSED [  7%]
tests/models/bart/test_modeling_flax_bart.py::BartHeadTests::test_shift_tokens_right PASSED [  9%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_attention_outputs <- tests/test_modeling_flax_common.py PASSED [ 11%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_beam_search_generate <- tests/generation/test_flax_utils.py PASSED [ 13%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_beam_search_generate_attn_mask <- tests/generation/test_flax_utils.py PASSED [ 15%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_beam_search_generate_logits_warper <- tests/generation/test_flax_utils.py PASSED [ 17%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_beam_search_generate_num_return_sequences <- tests/generation/test_flax_utils.py PASSED [ 19%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_checkpoint_sharding_from_hub <- tests/test_modeling_flax_common.py PASSED [ 21%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_checkpoint_sharding_local <- tests/test_modeling_flax_common.py PASSED [ 23%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_cnn_summarization_same_as_fairseq PASSED [ 25%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_decode PASSED [ 26%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_default_params_dtype <- tests/test_modeling_flax_common.py PASSED [ 28%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_encode PASSED [ 30%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_equivalence_flax_to_pt <- tests/test_modeling_flax_common.py PASSED [ 32%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_equivalence_pt_to_flax <- tests/test_modeling_flax_common.py PASSED [ 34%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_forward_signature <- tests/test_modeling_flax_common.py PASSED [ 36%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_from_pretrained_save_pretrained <- tests/test_modeling_flax_common.py PASSED [ 38%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_from_pretrained_with_no_automatic_init <- tests/test_modeling_flax_common.py PASSED [ 40%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_from_sharded_pt <- tests/test_modeling_flax_common.py PASSED [ 42%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_gradient_checkpointing <- tests/test_modeling_flax_common.py PASSED [ 44%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_greedy_generate <- tests/generation/test_flax_utils.py PASSED [ 46%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_greedy_generate_attn_mask <- tests/generation/test_flax_utils.py PASSED [ 48%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_greedy_generate_logits_warper <- tests/generation/test_flax_utils.py PASSED [ 50%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_greedy_generate_pt_fx <- tests/generation/test_flax_utils.py PASSED [ 51%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_headmasking <- tests/test_modeling_flax_common.py PASSED [ 53%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_hidden_states_output <- tests/test_modeling_flax_common.py PASSED [ 55%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_jit_compilation <- tests/test_modeling_flax_common.py PASSED [ 57%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_load_with_mismatched_shapes <- tests/test_modeling_flax_common.py PASSED [ 59%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_model_from_pretrained PASSED [ 61%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_model_main_input_name <- tests/test_modeling_flax_common.py PASSED [ 63%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_model_outputs_equivalence <- tests/test_modeling_flax_common.py PASSED [ 65%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_naming_convention <- tests/test_modeling_flax_common.py PASSED [ 67%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_no_automatic_init <- tests/test_modeling_flax_common.py PASSED [ 69%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_sample_generate <- tests/generation/test_flax_utils.py PASSED [ 71%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_sample_generate_attn_mask <- tests/generation/test_flax_utils.py PASSED [ 73%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_sample_generate_logits_warper <- tests/generation/test_flax_utils.py PASSED [ 75%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_save_load_bf16_to_base_pt <- tests/test_modeling_flax_common.py PASSED [ 76%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_save_load_from_base <- tests/test_modeling_flax_common.py PASSED [ 78%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_save_load_from_base_pt <- tests/test_modeling_flax_common.py PASSED [ 80%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_save_load_in_bf16 <- tests/test_modeling_flax_common.py PASSED [ 82%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_save_load_in_fp16 <- tests/test_modeling_flax_common.py PASSED [ 84%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_save_load_to_base <- tests/test_modeling_flax_common.py PASSED [ 86%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_save_load_to_base_pt <- tests/test_modeling_flax_common.py PASSED [ 88%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_summarization_fast PASSED [ 90%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_to_bf16 <- tests/test_modeling_flax_common.py PASSED [ 92%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_to_fp16 <- tests/test_modeling_flax_common.py PASSED [ 94%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_to_fp32 <- tests/test_modeling_flax_common.py PASSED [ 96%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_use_cache_forward PASSED [ 98%]
tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_use_cache_forward_with_attn_mask PASSED [100%]

=============================== warnings summary ===============================
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/_pytest/config/__init__.py:1439
  /home/gigant/miniconda3/envs/transformers-dev/lib/python3.9/site-packages/_pytest/config/__init__.py:1439: PytestConfigWarning: Unknown config option: doctest_glob
  
    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

tests/models/bart/test_modeling_flax_bart.py: 371 warnings
  /home/gigant/Documents/transformers_fix/transformers/tests/test_modeling_flax_common.py:795: DeprecationWarning: Please use assertEqual instead.
    self.assertEquals(type_, jnp.float32, msg=f"param {name} is not initialized in fp32.")

tests/models/bart/test_modeling_flax_bart.py::FlaxBartModelTest::test_equivalence_flax_to_pt
  /home/gigant/miniconda3/envs/transformers-dev/lib/python3.9/site-packages/transformers/modeling_flax_pytorch_utils.py:460: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
    pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================= 52 passed, 373 warnings in 166.31s (0:02:46) =================

RUN_SLOW=1 pytest tests/models/t5/test_modeling_flax_t5.py -vv gives 80 passed, 6 skipped:

======================================== test session starts ========================================
platform linux -- Python 3.9.18, pytest-8.1.1, pluggy-1.4.0 -- /home/gigant/miniconda3/envs/transformers-dev/bin/python
cachedir: .pytest_cache
rootdir: /home/gigant/Documents/transformers_fix/transformers
configfile: pyproject.toml
collected 86 items                                                                                  

tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_attention_outputs <- tests/test_modeling_flax_common.py PASSED [  1%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_beam_search_generate <- tests/generation/test_flax_utils.py PASSED [  2%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_beam_search_generate_attn_mask <- tests/generation/test_flax_utils.py PASSED [  3%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_beam_search_generate_logits_warper <- tests/generation/test_flax_utils.py PASSED [  4%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_beam_search_generate_num_return_sequences <- tests/generation/test_flax_utils.py PASSED [  5%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_checkpoint_sharding_from_hub <- tests/test_modeling_flax_common.py PASSED [  6%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_checkpoint_sharding_local <- tests/test_modeling_flax_common.py PASSED [  8%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_config PASSED                 [  9%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_decode PASSED                 [ 10%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_default_params_dtype <- tests/test_modeling_flax_common.py PASSED [ 11%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_encode PASSED                 [ 12%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_equivalence_flax_to_pt <- tests/test_modeling_flax_common.py PASSED [ 13%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_equivalence_pt_to_flax <- tests/test_modeling_flax_common.py PASSED [ 15%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_forward_signature <- tests/test_modeling_flax_common.py PASSED [ 16%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_from_pretrained_save_pretrained <- tests/test_modeling_flax_common.py PASSED [ 17%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_from_pretrained_with_no_automatic_init <- tests/test_modeling_flax_common.py PASSED [ 18%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_from_sharded_pt <- tests/test_modeling_flax_common.py PASSED [ 19%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_gradient_checkpointing <- tests/test_modeling_flax_common.py PASSED [ 20%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_greedy_generate <- tests/generation/test_flax_utils.py PASSED [ 22%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_greedy_generate_attn_mask <- tests/generation/test_flax_utils.py PASSED [ 23%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_greedy_generate_logits_warper <- tests/generation/test_flax_utils.py PASSED [ 24%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_greedy_generate_pt_fx <- tests/generation/test_flax_utils.py PASSED [ 25%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_headmasking <- tests/test_modeling_flax_common.py PASSED [ 26%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_hidden_states_output <- tests/test_modeling_flax_common.py PASSED [ 27%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_jit_compilation <- tests/test_modeling_flax_common.py PASSED [ 29%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_load_with_mismatched_shapes <- tests/test_modeling_flax_common.py PASSED [ 30%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_model PASSED                  [ 31%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_model_main_input_name <- tests/test_modeling_flax_common.py PASSED [ 32%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_model_outputs_equivalence <- tests/test_modeling_flax_common.py PASSED [ 33%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_model_v1_1 PASSED             [ 34%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_naming_convention <- tests/test_modeling_flax_common.py PASSED [ 36%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_no_automatic_init <- tests/test_modeling_flax_common.py PASSED [ 37%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_sample_generate <- tests/generation/test_flax_utils.py PASSED [ 38%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_sample_generate_attn_mask <- tests/generation/test_flax_utils.py PASSED [ 39%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_sample_generate_logits_warper <- tests/generation/test_flax_utils.py PASSED [ 40%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_save_load_bf16_to_base_pt PASSED [ 41%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_save_load_from_base PASSED    [ 43%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_save_load_from_base_pt PASSED [ 44%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_save_load_in_bf16 <- tests/test_modeling_flax_common.py PASSED [ 45%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_save_load_in_fp16 <- tests/test_modeling_flax_common.py PASSED [ 46%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_save_load_to_base PASSED      [ 47%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_save_load_to_base_pt PASSED   [ 48%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_shift_right PASSED            [ 50%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_to_bf16 <- tests/test_modeling_flax_common.py PASSED [ 51%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_to_fp16 <- tests/test_modeling_flax_common.py PASSED [ 52%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_to_fp32 <- tests/test_modeling_flax_common.py PASSED [ 53%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_use_cache_forward_with_attn_mask PASSED [ 54%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_attention_outputs <- tests/test_modeling_flax_common.py PASSED [ 55%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_checkpoint_sharding_from_hub <- tests/test_modeling_flax_common.py PASSED [ 56%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_checkpoint_sharding_local <- tests/test_modeling_flax_common.py PASSED [ 58%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_config PASSED      [ 59%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_default_params_dtype <- tests/test_modeling_flax_common.py PASSED [ 60%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_encode PASSED      [ 61%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_equivalence_flax_to_pt <- tests/test_modeling_flax_common.py PASSED [ 62%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_equivalence_pt_to_flax <- tests/test_modeling_flax_common.py PASSED [ 63%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_forward_signature <- tests/test_modeling_flax_common.py PASSED [ 65%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_from_pretrained_save_pretrained <- tests/test_modeling_flax_common.py PASSED [ 66%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_from_pretrained_with_no_automatic_init <- tests/test_modeling_flax_common.py PASSED [ 67%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_from_sharded_pt <- tests/test_modeling_flax_common.py PASSED [ 68%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_gradient_checkpointing <- tests/test_modeling_flax_common.py PASSED [ 69%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_headmasking <- tests/test_modeling_flax_common.py PASSED [ 70%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_hidden_states_output <- tests/test_modeling_flax_common.py PASSED [ 72%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_jit_compilation <- tests/test_modeling_flax_common.py PASSED [ 73%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_load_with_mismatched_shapes <- tests/test_modeling_flax_common.py PASSED [ 74%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_model PASSED       [ 75%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_model_main_input_name <- tests/test_modeling_flax_common.py PASSED [ 76%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_model_outputs_equivalence <- tests/test_modeling_flax_common.py PASSED [ 77%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_model_v1_1 PASSED  [ 79%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_naming_convention <- tests/test_modeling_flax_common.py PASSED [ 80%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_no_automatic_init <- tests/test_modeling_flax_common.py PASSED [ 81%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_save_load_bf16_to_base_pt PASSED [ 82%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_save_load_from_base PASSED [ 83%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_save_load_from_base_pt PASSED [ 84%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_save_load_in_bf16 <- tests/test_modeling_flax_common.py PASSED [ 86%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_save_load_in_fp16 <- tests/test_modeling_flax_common.py PASSED [ 87%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_save_load_to_base PASSED [ 88%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_save_load_to_base_pt PASSED [ 89%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_to_bf16 <- tests/test_modeling_flax_common.py PASSED [ 90%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_to_fp16 <- tests/test_modeling_flax_common.py PASSED [ 91%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5EncoderOnlyModelTest::test_to_fp32 <- tests/test_modeling_flax_common.py PASSED [ 93%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelIntegrationTests::test_small_byt5_integration_test SKIPPED [ 94%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelIntegrationTests::test_small_generation SKIPPED [ 95%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelIntegrationTests::test_small_generation_bfloat16 SKIPPED [ 96%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelIntegrationTests::test_small_integration_test SKIPPED [ 97%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelIntegrationTests::test_small_v1_1_integration_test SKIPPED [ 98%]
tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelIntegrationTests::test_summarization SKIPPED [100%]

========================================= warnings summary ==========================================
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/_pytest/config/__init__.py:1439
  /home/gigant/miniconda3/envs/transformers-dev/lib/python3.9/site-packages/_pytest/config/__init__.py:1439: PytestConfigWarning: Unknown config option: doctest_glob
  
    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

tests/models/t5/test_modeling_flax_t5.py: 113 warnings
  /home/gigant/Documents/transformers_fix/transformers/tests/test_modeling_flax_common.py:795: DeprecationWarning: Please use assertEqual instead.
    self.assertEquals(type_, jnp.float32, msg=f"param {name} is not initialized in fp32.")

tests/models/t5/test_modeling_flax_t5.py::FlaxT5ModelTest::test_equivalence_flax_to_pt
  /home/gigant/miniconda3/envs/transformers-dev/lib/python3.9/site-packages/transformers/modeling_flax_pytorch_utils.py:460: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
    pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================== 80 passed, 6 skipped, 115 warnings in 65.67s (0:01:05) =======================

RUN_SLOW=1 pytest tests/models/whisper/test_modeling_flax_whisper.py -vv gives 6 failed, 68 passed:

======================================== test session starts ========================================
platform linux -- Python 3.9.18, pytest-8.1.1, pluggy-1.4.0 -- /home/gigant/miniconda3/envs/transformers-dev/bin/python
cachedir: .pytest_cache
rootdir: /home/gigant/Documents/transformers_fix/transformers
configfile: pyproject.toml
collected 74 items                                                                                  

tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_attention_outputs <- tests/test_modeling_flax_common.py PASSED [  1%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_checkpoint_sharding_from_hub <- tests/test_modeling_flax_common.py PASSED [  2%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_checkpoint_sharding_local <- tests/test_modeling_flax_common.py PASSED [  4%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_config PASSED  [  5%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_default_params_dtype <- tests/test_modeling_flax_common.py PASSED [  6%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_encoder_sinusoidal_embed_positions PASSED [  8%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_equivalence_flax_to_pt <- tests/test_modeling_flax_common.py PASSED [  9%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_equivalence_pt_to_flax <- tests/test_modeling_flax_common.py PASSED [ 10%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_forward_signature PASSED [ 12%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_from_pretrained_save_pretrained <- tests/test_modeling_flax_common.py PASSED [ 13%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_from_pretrained_with_no_automatic_init <- tests/test_modeling_flax_common.py PASSED [ 14%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_from_sharded_pt <- tests/test_modeling_flax_common.py PASSED [ 16%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_gradient_checkpointing <- tests/test_modeling_flax_common.py PASSED [ 17%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_headmasking <- tests/test_modeling_flax_common.py PASSED [ 18%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_hidden_states_output <- tests/test_modeling_flax_common.py PASSED [ 20%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_jit_compilation PASSED [ 21%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_load_with_mismatched_shapes <- tests/test_modeling_flax_common.py PASSED [ 22%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_model_main_input_name <- tests/test_modeling_flax_common.py PASSED [ 24%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_model_outputs_equivalence <- tests/test_modeling_flax_common.py PASSED [ 25%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_naming_convention <- tests/test_modeling_flax_common.py PASSED [ 27%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_no_automatic_init <- tests/test_modeling_flax_common.py PASSED [ 28%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_save_load_bf16_to_base_pt PASSED [ 29%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_save_load_from_base PASSED [ 31%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_save_load_from_base_pt PASSED [ 32%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_save_load_in_bf16 <- tests/test_modeling_flax_common.py PASSED [ 33%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_save_load_in_fp16 <- tests/test_modeling_flax_common.py PASSED [ 35%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_save_load_to_base PASSED [ 36%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_save_load_to_base_pt PASSED [ 37%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_to_bf16 <- tests/test_modeling_flax_common.py PASSED [ 39%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_to_fp16 <- tests/test_modeling_flax_common.py PASSED [ 40%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_to_fp32 <- tests/test_modeling_flax_common.py PASSED [ 41%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_batched_generation FAILED [ 43%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_generation PASSED [ 44%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_generation_multilingual FAILED [ 45%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_logits_librispeech FAILED [ 47%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_small_en_logits_librispeech FAILED [ 48%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_en_batched_generation PASSED [ 50%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_en_generation FAILED [ 51%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_generation PASSED [ 52%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_logits_librispeech PASSED [ 54%]
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_timestamp_generation FAILED [ 55%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_attention_outputs <- tests/test_modeling_flax_common.py PASSED [ 56%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_checkpoint_sharding_from_hub <- tests/test_modeling_flax_common.py PASSED [ 58%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_checkpoint_sharding_local <- tests/test_modeling_flax_common.py PASSED [ 59%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_config PASSED [ 60%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_default_params_dtype <- tests/test_modeling_flax_common.py PASSED [ 62%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_equivalence_flax_to_pt <- tests/test_modeling_flax_common.py PASSED [ 63%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_equivalence_pt_to_flax <- tests/test_modeling_flax_common.py PASSED [ 64%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_forward_signature PASSED [ 66%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_from_pretrained_save_pretrained <- tests/test_modeling_flax_common.py PASSED [ 67%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_from_pretrained_with_no_automatic_init <- tests/test_modeling_flax_common.py PASSED [ 68%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_from_sharded_pt <- tests/test_modeling_flax_common.py PASSED [ 70%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_gradient_checkpointing <- tests/test_modeling_flax_common.py PASSED [ 71%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_headmasking <- tests/test_modeling_flax_common.py PASSED [ 72%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_hidden_states_output <- tests/test_modeling_flax_common.py PASSED [ 74%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_inputs_embeds PASSED [ 75%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_jit_compilation PASSED [ 77%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_load_with_mismatched_shapes <- tests/test_modeling_flax_common.py PASSED [ 78%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_model_common_attributes PASSED [ 79%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_model_main_input_name <- tests/test_modeling_flax_common.py PASSED [ 81%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_model_outputs_equivalence <- tests/test_modeling_flax_common.py PASSED [ 82%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_naming_convention <- tests/test_modeling_flax_common.py PASSED [ 83%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_no_automatic_init <- tests/test_modeling_flax_common.py PASSED [ 85%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_resize_tokens_embeddings PASSED [ 86%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_save_load_bf16_to_base_pt PASSED [ 87%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_save_load_from_base PASSED [ 89%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_save_load_from_base_pt PASSED [ 90%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_save_load_in_bf16 <- tests/test_modeling_flax_common.py PASSED [ 91%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_save_load_in_fp16 <- tests/test_modeling_flax_common.py PASSED [ 93%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_save_load_to_base PASSED [ 94%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_save_load_to_base_pt PASSED [ 95%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_to_bf16 <- tests/test_modeling_flax_common.py PASSED [ 97%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_to_fp16 <- tests/test_modeling_flax_common.py PASSED [ 98%]
tests/models/whisper/test_modeling_flax_whisper.py::WhisperEncoderModelTest::test_to_fp32 <- tests/test_modeling_flax_common.py PASSED [100%]

============================================= FAILURES ==============================================
___________________ FlaxWhisperModelIntegrationTest.test_large_batched_generation ___________________

self = <tests.models.whisper.test_modeling_flax_whisper.FlaxWhisperModelIntegrationTest testMethod=test_large_batched_generation>

    def test_large_batched_generation(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True)
    
        input_speech = self._load_datasamples(4)
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np").input_features
        generated_ids = model.generate(input_features, max_length=20).sequences
    
        # fmt: off
        EXPECTED_LOGITS = np.array(
            [
                [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281],
                [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257],
                [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256],
                [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11]
            ]
        )
        # fmt: on
    
>       self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS))
E       AssertionError: False is not true

tests/models/whisper/test_modeling_flax_whisper.py:613: AssertionError
--------------------------------------- Captured stderr call ----------------------------------------
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
________________ FlaxWhisperModelIntegrationTest.test_large_generation_multilingual _________________

self = <fsspec.implementations.http.HTTPFileSystem object at 0x7f4b0227d8e0>
url = 'https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/cv-corpus-6.1-2020-12-11/ja.tar.gz'
kwargs = {}, info = {}, session = <aiohttp.client.ClientSession object at 0x7f4b028e0250>
policy = 'get'

    async def _info(self, url, **kwargs):
        """Get info of URL
    
        Tries to access location via HEAD, and then GET methods, but does
        not fetch the data.
    
        It is possible that the server does not supply any size information, in
        which case size will be given as None (and certain operations on the
        corresponding file will not work).
        """
        info = {}
        session = await self.set_session()
    
        for policy in ["head", "get"]:
            try:
                info.update(
>                   await _file_info(
                        self.encode_url(url),
                        size_policy=policy,
                        session=session,
                        **self.kwargs,
                        **kwargs,
                    )
                )

../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/fsspec/implementations/http.py:419: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/fsspec/implementations/http.py:832: in _file_info
    r.raise_for_status()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <ClientResponse(https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/cv-corpus-6.1-202...e': 'application/xml', 'Transfer-Encoding': 'chunked', 'Date': 'Mon, 18 Mar 2024 18:09:08 GMT', 'Server': 'AmazonS3')>


    def raise_for_status(self) -> None:
        if not self.ok:
            # reason should always be not None for a started response
            assert self.reason is not None
            self.release()
>           raise ClientResponseError(
                self.request_info,
                self.history,
                status=self.status,
                message=self.reason,
                headers=self.headers,
            )
E           aiohttp.client_exceptions.ClientResponseError: 403, message='Forbidden', url=URL('https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/cv-corpus-6.1-2020-12-11/ja.tar.gz')

../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/aiohttp/client_reqrep.py:1060: ClientResponseError

The above exception was the direct cause of the following exception:

self = <tests.models.whisper.test_modeling_flax_whisper.FlaxWhisperModelIntegrationTest testMethod=test_large_generation_multilingual>

    def test_large_generation_multilingual(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True)
    
        ds = load_dataset("common_voice", "ja", split="test", streaming=True)
        ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000))
>       input_speech = next(iter(ds))["audio"]["array"]

tests/models/whisper/test_modeling_flax_whisper.py:566: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/datasets/iterable_dataset.py:1388: in __iter__
    for key, example in ex_iterable:
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/datasets/iterable_dataset.py:234: in __iter__
    yield from self.generate_examples_fn(**self.kwargs)
../../../.cache/huggingface/modules/datasets_modules/datasets/common_voice/220833898d6a60c50f621126e51fb22eb2dfe5244392c70dccd8e6e2f055f4bf/common_voice.py:774: in _generate_examples
    for path, f in archive_iterator:
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/datasets/download/streaming_download_manager.py:869: in __iter__
    yield from self.generator(*self.args, **self.kwargs)
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/datasets/download/streaming_download_manager.py:922: in _iter_from_urlpath
    with xopen(urlpath, "rb", download_config=download_config, block_size=0) as f:
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/datasets/download/streaming_download_manager.py:512: in xopen
    file_obj = fsspec.open(file, mode=mode, *args, **kwargs).open()
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/fsspec/core.py:135: in open
    return self.__enter__()
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/fsspec/core.py:103: in __enter__
    f = self.fs.open(self.path, mode=mode)
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/fsspec/spec.py:1293: in open
    f = self._open(
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/fsspec/implementations/http.py:358: in _open
    size = size or self.info(path, **kwargs)["size"]
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/fsspec/asyn.py:118: in wrapper
    return sync(self.loop, func, *args, **kwargs)
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/fsspec/asyn.py:103: in sync
    raise return_result
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/fsspec/asyn.py:56: in _runner
    result[0] = await coro
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <fsspec.implementations.http.HTTPFileSystem object at 0x7f4b0227d8e0>
url = 'https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/cv-corpus-6.1-2020-12-11/ja.tar.gz'
kwargs = {}, info = {}, session = <aiohttp.client.ClientSession object at 0x7f4b028e0250>
policy = 'get'

    async def _info(self, url, **kwargs):
        """Get info of URL
    
        Tries to access location via HEAD, and then GET methods, but does
        not fetch the data.
    
        It is possible that the server does not supply any size information, in
        which case size will be given as None (and certain operations on the
        corresponding file will not work).
        """
        info = {}
        session = await self.set_session()
    
        for policy in ["head", "get"]:
            try:
                info.update(
                    await _file_info(
                        self.encode_url(url),
                        size_policy=policy,
                        session=session,
                        **self.kwargs,
                        **kwargs,
                    )
                )
                if info.get("size") is not None:
                    break
            except Exception as exc:
                if policy == "get":
                    # If get failed, then raise a FileNotFoundError
>                   raise FileNotFoundError(url) from exc
E                   FileNotFoundError: https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/cv-corpus-6.1-2020-12-11/ja.tar.gz

../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/fsspec/implementations/http.py:432: FileNotFoundError
--------------------------------------- Captured stderr call ----------------------------------------
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
___________________ FlaxWhisperModelIntegrationTest.test_large_logits_librispeech ___________________

self = <tests.models.whisper.test_modeling_flax_whisper.FlaxWhisperModelIntegrationTest testMethod=test_large_logits_librispeech>

    def test_large_logits_librispeech(self):
        model = FlaxWhisperModel.from_pretrained("openai/whisper-large", from_pt=True)
        input_speech = self._load_datasamples(1)
        processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        processed_inputs = processor(
            audio=input_speech, text="This part of the speech", add_special_tokens=False, return_tensors="np"
        )
        input_features = processed_inputs.input_features
        decoder_input_ids = processed_inputs.labels
    
        logits = model(
            input_features,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=False,
            output_attentions=False,
            return_dict=False,
        )
    
>       logits = logits[0] @ model.params["model"]["decoder"]["embed_tokens"]["embedding"].T
E       KeyError: 'model'

tests/models/whisper/test_modeling_flax_whisper.py:492: KeyError
--------------------------------------- Captured stderr call ----------------------------------------
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
_________________ FlaxWhisperModelIntegrationTest.test_small_en_logits_librispeech __________________

self = <tests.models.whisper.test_modeling_flax_whisper.FlaxWhisperModelIntegrationTest testMethod=test_small_en_logits_librispeech>

    def test_small_en_logits_librispeech(self):
        model = FlaxWhisperModel.from_pretrained("openai/whisper-small.en", from_pt=True)
        input_speech = self._load_datasamples(1)
        feature_extractor = WhisperFeatureExtractor()
        input_features = feature_extractor(input_speech, return_tensors="np").input_features
    
>       logits = model(
            input_features,
            decoder_input_ids=np.array([model.config.decoder_start_token_id]),
            output_hidden_states=False,
            output_attentions=False,
            return_dict=False,
        )

tests/models/whisper/test_modeling_flax_whisper.py:451: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <transformers.models.whisper.modeling_flax_whisper.FlaxWhisperModel object at 0x7f4b2e3ee0a0>
input_features = array([[[ 1.1933082e-01, -9.4576120e-02, -1.0977852e-01, ...,
         -8.0602670e-01, -8.0602670e-01, -8.0602670e-01]...70e-01, -8.0602670e-01, -8.0602670e-01, ...,
         -8.0602670e-01, -8.0602670e-01, -8.0602670e-01]]], dtype=float32)
decoder_input_ids = array([50257]), attention_mask = None, decoder_attention_mask = None
position_ids = None, decoder_position_ids = None, output_attentions = False
output_hidden_states = False, return_dict = False, train = False, params = None, dropout_rng = None

    @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_features: jnp.ndarray,
        decoder_input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict
    
        # prepare decoder inputs
        if decoder_position_ids is None:
            if decoder_attention_mask is not None:
                decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
            else:
>               batch_size, sequence_length = decoder_input_ids.shape
E               ValueError: not enough values to unpack (expected 2, got 1)

../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/transformers/models/whisper/modeling_flax_whisper.py:1161: ValueError
--------------------------------------- Captured stderr call ----------------------------------------
It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
______________________ FlaxWhisperModelIntegrationTest.test_tiny_en_generation ______________________

self = <tests.models.whisper.test_modeling_flax_whisper.FlaxWhisperModelIntegrationTest testMethod=test_tiny_en_generation>

    def test_tiny_en_generation(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
        model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
        model.config.decoder_start_token_id = 50257
    
        input_speech = self._load_datasamples(1)
        input_features = processor.feature_extractor(
            raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax"
        ).input_features
    
        generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences
        transcript = processor.tokenizer.decode(generated_ids[0])
    
        EXPECTED_TRANSCRIPT = (
            "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle"
            " classes and we are glad to"
        )
>       self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
E       AssertionError: '<|st[14 chars]t|><|notimestamps|> Mr. Quilter is the apostle[84 chars]xt|>' != '<|st[14 chars]t|><|en|><|transcribe|><|notimestamps|> Mr. Qu[57 chars]d to'
E       - <|startoftranscript|><|notimestamps|> Mr. Quilter is the apostle of the middle classes,<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
E       + <|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle classes and we are glad to

tests/models/whisper/test_modeling_flax_whisper.py:523: AssertionError
__________________ FlaxWhisperModelIntegrationTest.test_tiny_timestamp_generation ___________________

self = <tests.models.whisper.test_modeling_flax_whisper.FlaxWhisperModelIntegrationTest testMethod=test_tiny_timestamp_generation>

    @slow
    def test_tiny_timestamp_generation(self):
        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
        model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    
        input_speech = np.concatenate(self._load_datasamples(4))
        input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="jax").input_features
    
        generate_fn = jax.jit(functools.partial(model.generate, max_length=448, return_timestamps=True))
    
        generated_ids = generate_fn(input_features)
    
        EXPECTED_OUTPUT = np.array([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50692, 50692, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50926, 50926, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 51208, 51208, 949, 505, 11, 14138, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51552, 51552, 634, 575, 12525, 22618, 1968, 6144, 35617, 7354, 1292, 6, 589, 307, 534, 10281, 934, 439, 11, 293, 51836, 51836, 50257])  # fmt: skip
    
>       self.assertTrue(np.allclose(generated_ids, EXPECTED_OUTPUT))

tests/models/whisper/test_modeling_flax_whisper.py:675: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/numpy/core/numeric.py:2241: in allclose
    res = all(isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

a = FlaxGreedySearchOutput(sequences=Array([[50258, 50259, 50359, 50364,  2221,    13,  2326,   388,   391,
          307,...257, 50257, 50257, 50257, 50257, 50257, 50257,
        50257, 50257, 50257, 50257, 50257, 50257, 50257]], dtype=int32))
b = array([50258, 50259, 50359, 50364,  2221,    13,  2326,   388,   391,
         307,   264, 50244,   295,   264,  2808,... 6144, 35617,  7354,  1292,     6,   589,   307,   534, 10281,
         934,   439,    11,   293, 51836, 51836, 50257])
rtol = 1e-05, atol = 1e-08, equal_nan = False

    @array_function_dispatch(_isclose_dispatcher)
    def isclose(a, b, rtol=1.e-5, atol=1.e-8, equal_nan=False):
        """
        Returns a boolean array where two arrays are element-wise equal within a
        tolerance.
    
        The tolerance values are positive, typically very small numbers.  The
        relative difference (`rtol` * abs(`b`)) and the absolute difference
        `atol` are added together to compare against the absolute difference
        between `a` and `b`.
    
        .. warning:: The default `atol` is not appropriate for comparing numbers
                     that are much smaller than one (see Notes).
    
        Parameters
        ----------
        a, b : array_like
            Input arrays to compare.
        rtol : float
            The relative tolerance parameter (see Notes).
        atol : float
            The absolute tolerance parameter (see Notes).
        equal_nan : bool
            Whether to compare NaN's as equal.  If True, NaN's in `a` will be
            considered equal to NaN's in `b` in the output array.
    
        Returns
        -------
        y : array_like
            Returns a boolean array of where `a` and `b` are equal within the
            given tolerance. If both `a` and `b` are scalars, returns a single
            boolean value.
    
        See Also
        --------
        allclose
        math.isclose
    
        Notes
        -----
        .. versionadded:: 1.7.0
    
        For finite values, isclose uses the following equation to test whether
        two floating point values are equivalent.
    
         absolute(`a` - `b`) <= (`atol` + `rtol` * absolute(`b`))
    
        Unlike the built-in `math.isclose`, the above equation is not symmetric
        in `a` and `b` -- it assumes `b` is the reference value -- so that
        `isclose(a, b)` might be different from `isclose(b, a)`. Furthermore,
        the default value of atol is not zero, and is used to determine what
        small values should be considered close to zero. The default value is
        appropriate for expected values of order unity: if the expected values
        are significantly smaller than one, it can result in false positives.
        `atol` should be carefully selected for the use case at hand. A zero value
        for `atol` will result in `False` if either `a` or `b` is zero.
    
        `isclose` is not defined for non-numeric data types.
        `bool` is considered a numeric data-type for this purpose.
    
        Examples
        --------
        >>> np.isclose([1e10,1e-7], [1.00001e10,1e-8])
        array([ True, False])
        >>> np.isclose([1e10,1e-8], [1.00001e10,1e-9])
        array([ True, True])
        >>> np.isclose([1e10,1e-8], [1.0001e10,1e-9])
        array([False,  True])
        >>> np.isclose([1.0, np.nan], [1.0, np.nan])
        array([ True, False])
        >>> np.isclose([1.0, np.nan], [1.0, np.nan], equal_nan=True)
        array([ True, True])
        >>> np.isclose([1e-8, 1e-7], [0.0, 0.0])
        array([ True, False])
        >>> np.isclose([1e-100, 1e-7], [0.0, 0.0], atol=0.0)
        array([False, False])
        >>> np.isclose([1e-10, 1e-10], [1e-20, 0.0])
        array([ True,  True])
        >>> np.isclose([1e-10, 1e-10], [1e-20, 0.999999e-10], atol=0.0)
        array([False,  True])
        """
        def within_tol(x, y, atol, rtol):
            with errstate(invalid='ignore'), _no_nep50_warning():
                return less_equal(abs(x-y), atol + rtol * abs(y))
    
        x = asanyarray(a)
        y = asanyarray(b)
    
        # Make sure y is an inexact type to avoid bad behavior on abs(MIN_INT).
        # This will cause casting of x later. Also, make sure to allow subclasses
        # (e.g., for numpy.ma).
        # NOTE: We explicitly allow timedelta, which used to work. This could
        #       possibly be deprecated. See also gh-18286.
        #       timedelta works if `atol` is an integer or also a timedelta.
        #       Although, the default tolerances are unlikely to be useful
        if y.dtype.kind != "m":
            dt = multiarray.result_type(y, 1.)
            y = asanyarray(y, dtype=dt)
    
>       xfin = isfinite(x)
E       TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''

../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/numpy/core/numeric.py:2348: TypeError
--------------------------------------- Captured stderr call ----------------------------------------
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
========================================= warnings summary ==========================================
../../../miniconda3/envs/transformers-dev/lib/python3.9/site-packages/_pytest/config/__init__.py:1439
  /home/gigant/miniconda3/envs/transformers-dev/lib/python3.9/site-packages/_pytest/config/__init__.py:1439: PytestConfigWarning: Unknown config option: doctest_glob
  
    self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")

tests/models/whisper/test_modeling_flax_whisper.py: 219 warnings
  /home/gigant/Documents/transformers_fix/transformers/tests/test_modeling_flax_common.py:795: DeprecationWarning: Please use assertEqual instead.
    self.assertEquals(type_, jnp.float32, msg=f"param {name} is not initialized in fp32.")

tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_encoder_sinusoidal_embed_positions
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_encoder_sinusoidal_embed_positions
  /home/gigant/miniconda3/envs/transformers-dev/lib/python3.9/site-packages/transformers/models/whisper/modeling_flax_whisper.py:72: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    return jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1).astype(dtype)

tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelTest::test_equivalence_flax_to_pt
  /home/gigant/miniconda3/envs/transformers-dev/lib/python3.9/site-packages/transformers/modeling_flax_pytorch_utils.py:460: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
    pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)

tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_batched_generation
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_generation
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_logits_librispeech
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_small_en_logits_librispeech
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_en_batched_generation
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_en_generation
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_generation
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_logits_librispeech
tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_timestamp_generation
  /home/gigant/miniconda3/envs/transformers-dev/lib/python3.9/site-packages/datasets/load.py:1461: FutureWarning: The repository for hf-internal-testing/librispeech_asr_dummy contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/hf-internal-testing/librispeech_asr_dummy
  You can avoid this message in future by passing the argument `trust_remote_code=True`.
  Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
    warnings.warn(

tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_batched_generation
  /home/gigant/miniconda3/envs/transformers-dev/lib/python3.9/site-packages/librosa/core/intervals.py:8: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
    from pkg_resources import resource_filename

tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_generation_multilingual
  /home/gigant/miniconda3/envs/transformers-dev/lib/python3.9/site-packages/datasets/load.py:1461: FutureWarning: The repository for common_voice contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/common_voice
  You can avoid this message in future by passing the argument `trust_remote_code=True`.
  Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
    warnings.warn(

tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_generation_multilingual
  /home/gigant/.cache/huggingface/modules/datasets_modules/datasets/common_voice/220833898d6a60c50f621126e51fb22eb2dfe5244392c70dccd8e6e2f055f4bf/common_voice.py:634: FutureWarning: 
              This version of the Common Voice dataset is deprecated.
              You can download the latest one with
              >>> load_dataset("mozilla-foundation/common_voice_11_0", "en")
              
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================================== short test summary info ======================================
FAILED tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_batched_generation - AssertionError: False is not true
FAILED tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_generation_multilingual - FileNotFoundError: https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazon...
FAILED tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_large_logits_librispeech - KeyError: 'model'
FAILED tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_small_en_logits_librispeech - ValueError: not enough values to unpack (expected 2, got 1)
FAILED tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_en_generation - AssertionError: '<|st[14 chars]t|><|notimestamps|> Mr. Quilter is the apostle[84 chars]xt|>' != ...
FAILED tests/models/whisper/test_modeling_flax_whisper.py::FlaxWhisperModelIntegrationTest::test_tiny_timestamp_generation - TypeError: ufunc 'isfinite' not supported for the input types, and the inputs could not be safel...
====================== 6 failed, 68 passed, 235 warnings in 491.73s (0:08:11) =======================

@gante gante left a comment
@giganttheo thank you for the detailed explanations 💛

The same 6 slow tests are failing on main, so they are not a result of this PR (cc @sanchit-gandhi)

@gante gante requested a review from amyeroberts March 20, 2024 11:59
@amyeroberts amyeroberts left a comment
Thanks for fixing!

@amyeroberts amyeroberts merged commit fd734be into huggingface:main Mar 21, 2024
18 checks passed
@giganttheo giganttheo deleted the fix/logits_processor_flax branch March 21, 2024 14:57
Successfully merging this pull request may close these issues.

Unexpected behaviour of logit processor during beam search generation in Flax