[FLAX] Whisper #19512

Closed
wants to merge 11 commits into from

Conversation

kamalkraj
Contributor

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@kamalkraj
Contributor Author

Hi,

I need a little clarification about implementing the FlaxWhisperDecoder module.

What would be the best way to pass past_key_values_length to the module?

For reference, the PyTorch implementation:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py#L863-L873
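
For example, would something along these lines be reasonable: passing explicit position_ids (offset by the caller during cached decoding) instead of a scalar past_key_values_length, similar to how the existing Flax seq2seq models such as FlaxBart handle it? A rough sketch (module name, dimensions, and defaults below are only placeholders):

import flax.linen as nn

class FlaxWhisperDecoderEmbedSketch(nn.Module):
    # Hypothetical sketch: only the positional-embedding lookup is shown.
    d_model: int = 384               # placeholder hidden size
    max_target_positions: int = 448  # placeholder max decoder length

    def setup(self):
        self.embed_positions = nn.Embed(self.max_target_positions, self.d_model)

    def __call__(self, inputs_embeds, position_ids):
        # position_ids has shape (batch, seq_len); during cached decoding the
        # caller offsets it by the number of tokens already generated, which
        # plays the role of past_key_values_length in the PyTorch code.
        return inputs_embeds + self.embed_positions(position_ids.astype("i4"))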

@patrickvonplaten @ydshieh @patil-suraj

@patrickvonplaten
Contributor

Whisper on TPU will make 🔥 colab demos

@kamalkraj
Contributor Author

[Screenshot attached, 2022-10-16]

@ArthurZucker
Collaborator

Awesome work here! Feel free to ping me for a review once it is ready 😄

@kamalkraj
Contributor Author

kamalkraj commented Oct 19, 2022

Hi,

I have finished the model and am working on the test cases now.
The PT<->Flax equivalence test is failing, even though model.generate produces exactly the same speech-to-text output as the PyTorch model.

[Screenshot of the failing PT<->Flax equivalence test output, 2022-10-19]

I have attached steps to reproduce the issue in this notebook - https://colab.research.google.com/drive/1KmO8OBUpHfs1uYA_eSwamQAXnjsdbkRS?usp=sharing
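
For anyone who cannot open the notebook, the core of the check is roughly the following (the checkpoint name and dummy input shapes are placeholders; FlaxWhisperForConditionalGeneration is the class added in this branch):

import numpy as np
import torch
from transformers import WhisperForConditionalGeneration
# Available on this PR's branch:
from transformers import FlaxWhisperForConditionalGeneration

model_id = "openai/whisper-tiny.en"  # placeholder checkpoint
pt_model = WhisperForConditionalGeneration.from_pretrained(model_id)
fx_model = FlaxWhisperForConditionalGeneration.from_pretrained(model_id, from_pt=True)

# Whisper expects log-mel features of shape (batch, 80 mel bins, 3000 frames).
input_features = np.random.randn(1, 80, 3000).astype("float32")
decoder_input_ids = np.array([[pt_model.config.decoder_start_token_id]])

with torch.no_grad():
    pt_logits = pt_model(
        torch.tensor(input_features),
        decoder_input_ids=torch.tensor(decoder_input_ids),
    ).logits.numpy()

fx_logits = np.asarray(
    fx_model(input_features, decoder_input_ids=decoder_input_ids).logits
)

print("max abs diff in logits:", np.abs(pt_logits - fx_logits).max())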

Any pointers will be helpful.

Thanks

@patrickvonplaten @patil-suraj @ydshieh @ArthurZucker

@ydshieh
Collaborator

ydshieh commented Oct 20, 2022

Hi @kamalkraj, first of all, thank you for this awesome PR!

Regarding the PT/Flax tests, I probably need to improve the PT/Flax equivalence test to make it (a bit) easier to find out which layers give the largest differences.

In the meantime, I have to say there is no easy way to debug such an issue. We need patience to find the first layer(s) with a large difference (greater than the tolerance) and then see what's wrong inside that layer.

This is usually a tedious process that involves manual debugging.

Anyway, I can open a PR to make the process (a bit) easier, if you want to wait a bit. But note that a similar process will still be needed even after that PR is merged.

@ArthurZucker
Collaborator

Will try to get #18420 merged so that we can maybe use the find_pt_fx_differences(pt_outputs, fx_outputs) function! But in the meantime, you should set output_hidden_states=True and check where the lists differ 🤗
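
A rough sketch of that kind of check (the helper name and tolerance below are only illustrative, not an existing utility):

import numpy as np
import torch

def report_hidden_state_diffs(pt_model, fx_model, input_features, decoder_input_ids, atol=1e-3):
    """Run both models with output_hidden_states=True and print the per-layer
    max absolute difference, so the first diverging layer stands out."""
    with torch.no_grad():
        pt_out = pt_model(
            torch.tensor(input_features),
            decoder_input_ids=torch.tensor(decoder_input_ids),
            output_hidden_states=True,
        )
    fx_out = fx_model(
        input_features,
        decoder_input_ids=decoder_input_ids,
        output_hidden_states=True,
    )
    for i, (pt_h, fx_h) in enumerate(
        zip(pt_out.decoder_hidden_states, fx_out.decoder_hidden_states)
    ):
        diff = np.abs(pt_h.numpy() - np.asarray(fx_h)).max()
        marker = "  <-- above tolerance" if diff > atol else ""
        print(f"decoder hidden state {i}: max abs diff = {diff:.3e}{marker}")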

@ydshieh
Collaborator

ydshieh commented Oct 20, 2022

Hi @kamalkraj, actually that test is already quite good, but we need to change it a bit to debug further.

The last 2 commits in this branch log more information.

If you run the tests like

RUN_PT_FLAX_CROSS_TESTS=true python3 -m pytest -v tests/models/whisper/test_modeling_flax_whisper.py -k "test_equivalence_pt_to_flax"

it logs something like

max diff. in outputs.logits: 0.0020506680011749268

but it doesn't fail the test; it continues. So far, I got:

E   AssertionError: <class 'list'> != <class 'tuple'> : outputs.decoder_hidden_states: Output types differ between Flax and PyTorch

so you will have to look at the output type of decoder_hidden_states and make sure the type is the same as on the PyTorch side.
Continuing this process will eventually show you all the differences and give you a better idea of where to debug in the modeling code.
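
For example, if the Flax module collects the per-layer states in a Python list, casting it before building the output object makes the type match (just an illustration, not code from this PR):

import jax.numpy as jnp
from transformers.modeling_flax_outputs import FlaxBaseModelOutput

# Stand-in per-layer hidden states; in the real module these come from the decoder layers.
all_hidden_states = [jnp.zeros((1, 4, 8)) for _ in range(3)]

outputs = FlaxBaseModelOutput(
    last_hidden_state=all_hidden_states[-1],
    hidden_states=tuple(all_hidden_states),  # tuple, not list, to mirror PyTorch
)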

Also, it seems that when running the tests from tests/models/whisper/test_modeling_whisper.py, we have a shape issue. This is another thing to debug.

Hopefully this gives you some idea of how we can debug here 🤗

@kamalkraj
Contributor Author

Thanks, @ydshieh and @ArthurZucker

@andyehrenberg
Contributor

andyehrenberg commented Nov 25, 2022

To make for a more consistent API across models, couldn't we swap out past_key_values_length and instead compute position_ids to get the current positional embeddings for the decoder? It feels like this would make it easier to fit Whisper in with other fine-tuning codebases (no need to create custom logic for computing past_key_values_length when dealing with Whisper). As the code currently stands, I think it would actually give incorrect outputs when decoding a batch in which each element has different decoder prefix/prompt tokens. Computing position_ids from the attention mask would also allow for either left or right padding.
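
Concretely, the kind of computation I mean is something like this (a rough sketch, not tied to any particular model class):

import jax.numpy as jnp

def position_ids_from_attention_mask(attention_mask):
    # Cumulative count of real (non-padding) tokens, shifted to start at 0;
    # padding positions are clamped to 0 and masked out elsewhere anyway.
    position_ids = jnp.cumsum(attention_mask, axis=-1) - 1
    return jnp.clip(position_ids, 0)

# Example: two prompts of different lengths, left-padded.
mask = jnp.array([[0, 0, 1, 1, 1],
                  [1, 1, 1, 1, 1]])
print(position_ids_from_attention_mask(mask))
# [[0 0 0 1 2]
#  [0 1 2 3 4]]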

I have another Flax Whisper implementation with .from_pretrained(..., from_pt=True) working correctly and giving correct outputs for variable-length prompts, which I'd be happy to share (or create a separate PR for). It also adds some things to the generation utilities to support prompt tokens for the decoder that already exist in the PyTorch utilities (using the prompt tokens instead of model.config.decoder_start_token_id if specified).

@ydshieh
Collaborator

ydshieh commented Nov 25, 2022

I haven't looked into this yet. But @andyehrenberg, are you suggesting a different way of computing this in Flax Whisper than the one implemented in our PyTorch/TensorFlow Whisper?

It would also be better for @kamalkraj to say whether he would like to continue this PR before we go ahead.

@kamalkraj
Contributor Author

@ydshieh @andyehrenberg

If there is already a working implementation, please go ahead with it.
I am closing this one.

Thanks

@kamalkraj kamalkraj closed this Nov 25, 2022
@andyehrenberg
Contributor

@ydshieh I guess what I'm suggesting here could also be helpful for the PyTorch/TF implementations, to improve flexibility/compatibility with existing codebases that use position_ids for other models (such as when fine-tuning).

For example, the use case I'm working on is fine-tuning Whisper with RL (trying to expose it to its own outputs to reduce hallucinations). At each step when collecting rollouts, it is given a batch of audio features and decoder prompts (from previous audio snippets) - these prompts are of varying lengths, so padding/attention masks are needed, and the position embeddings need to adjust accordingly. Then, when doing PPO updates on these steps, the position embeddings need to be computed correctly based on which timesteps (tokens) are padding.

The implementation in this PR wouldn't accommodate this scenario, as it assumes the same past_key_values_length for each sequence in the batch, whereas the implementation I've worked on uses position_ids to keep track of where we are in each sequence of the batch. Earlier I used a different method that relied only on the attention mask along with another caching method in the decoder, but using position_ids is much simpler and accommodates multiple padding schemes.
