[FLAX] Whisper #19512
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Hi, I need a little clarification about implementing the [...]. What would be the best way to pass [...]? Reference in the PyTorch implementation: [...]
Whisper on TPU will make 🔥 Colab demos
Awesome work here! Feel free to ping me for a review once it is ready 😄
Hi, I have finished the model and am working on the test cases now. I have attached steps to reproduce the issue in this notebook: https://colab.research.google.com/drive/1KmO8OBUpHfs1uYA_eSwamQAXnjsdbkRS?usp=sharing Any pointers will be helpful. Thanks
Hi @kamalkraj First, thank you for this awesome PR! Regarding the PT/Flax tests, I probably need to improve the PT/Flax equivalence tests to make it (a bit) easier to find out which layer gives the larger difference. In the meantime, I have to say there is no easy way to debug such issues. We need patience to find out at which layer(s) we have the first large difference (greater than the tolerance) and see what's wrong inside that layer. This is usually a tedious process that involves manual debugging. Anyway, I can open a PR to make the process (a bit) easier, if you want to wait a bit. But note that we will still need a similar process even after that PR is merged.
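For anyone debugging this, a rough sketch of the layer-by-layer comparison described above (not part of this PR). It assumes the Flax class ends up being called `FlaxWhisperForConditionalGeneration` and can load PyTorch weights with `from_pt=True`; the checkpoint and dummy inputs are only for illustration.

```python
# Sketch: compare PT and Flax decoder hidden states layer by layer to find
# the first layer where the difference exceeds the tolerance.
# Assumptions: class name FlaxWhisperForConditionalGeneration, from_pt=True loading.
import numpy as np
import torch
from transformers import WhisperForConditionalGeneration, FlaxWhisperForConditionalGeneration

pt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
fx_model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", from_pt=True)

input_features = np.random.randn(1, 80, 3000).astype(np.float32)  # dummy log-mel features
decoder_input_ids = np.array([[50258, 50259]])  # arbitrary decoder prompt tokens

with torch.no_grad():
    pt_out = pt_model(
        input_features=torch.tensor(input_features),
        decoder_input_ids=torch.tensor(decoder_input_ids),
        output_hidden_states=True,
    )
fx_out = fx_model(
    input_features=input_features,
    decoder_input_ids=decoder_input_ids,
    output_hidden_states=True,
)

for i, (pt_h, fx_h) in enumerate(zip(pt_out.decoder_hidden_states, fx_out.decoder_hidden_states)):
    diff = np.abs(pt_h.numpy() - np.asarray(fx_h)).max()
    print(f"decoder hidden state {i}: max abs diff = {diff:.2e}")
```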
Will try to get #18420 merged so that we can maybe use the [...]
Hi @kamalkraj Actually that test is already quite good, but we need to change it a bit to debug further. The last 2 commits in this branch log more information. If you run the tests like `RUN_PT_FLAX_CROSS_TESTS=true python3 -m pytest -v tests/models/whisper/test_modeling_flax_whisper.py -k "test_equivalence_pt_to_flax"`, it logs some information but doesn't fail the test -> it continues. So far, I got `E AssertionError: <class 'list'> != <class 'tuple'> : outputs.decoder_hidden_states: Output types differ between Flax and PyTorch`, so you will have to look at the output type of [...]. Also, it seems when running the tests from [...]. Hopefully this gives you some idea of how we can debug here 🤗
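A small illustrative sketch, not from the PR itself: the assertion above points at a container-type mismatch, which can be narrowed down with a check like the following (the attribute names are just examples).

```python
# Hypothetical helper to spot container-type mismatches between PT and Flax
# outputs; the attribute names below are examples, not an exhaustive list.
def compare_output_types(pt_out, fx_out, names=("decoder_hidden_states", "encoder_hidden_states")):
    for name in names:
        pt_val = getattr(pt_out, name, None)
        fx_val = getattr(fx_out, name, None)
        print(f"{name}: PT -> {type(pt_val)}, Flax -> {type(fx_val)}")

# The usual remedy on the Flax side is to collect per-layer hidden states in a
# Python list during the forward pass but cast with tuple(...) before building
# the output object, so the container type matches the PyTorch implementation.
```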
Thanks, @ydshieh and @ArthurZucker
To make for a more consistent API across models, couldn't we swap out [...]? I have another Flax Whisper implementation with `.from_pretrained(..., from_pt=True)` working correctly and giving correct outputs for variable-length prompts, which I'd be happy to share (or create a separate PR for). It also adds some things to the generation utilities to support passing prompt tokens to the decoder that already exist in the PyTorch utilities (using prompt tokens instead of [...]).
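For reference, a minimal usage sketch of the cross-loading path mentioned above; the class name `FlaxWhisperForConditionalGeneration` is an assumption about the final API, and the checkpoint and silent audio are only stand-ins.

```python
# Sketch: load a PyTorch Whisper checkpoint into the Flax model and run generation.
import numpy as np
from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", from_pt=True)

# one second of silence as a stand-in for real audio
audio = np.zeros(16000, dtype=np.float32)
inputs = processor(audio, sampling_rate=16000, return_tensors="np")

# from_pt=True converts the PyTorch weights to Flax parameters on the fly,
# which is the same path the PT/Flax equivalence tests exercise.
generated_ids = model.generate(inputs.input_features).sequences
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```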
I haven't looked into this. But @andyehrenberg, do you suggest a different way of computation in Flax Whisper than the one implemented in our PyTorch/TensorFlow Whisper? It's also better for @kamalkraj to say whether he would like to continue this PR before we go ahead.
If there is already a working implementation, please continue. Thanks
@ydshieh I guess what I'm suggesting here could also be helpful for the PyTorch/TF implementations, to improve flexibility/compatibility with existing codebases that use [...]. For example, the use case I'm working on is fine-tuning Whisper with RL (trying to expose it to its own outputs to reduce hallucinations). At each step when collecting rollouts, the model is given a batch of audio features and decoder prompts (from previous audio snippets). These prompts are of varying lengths, so padding/attention masks are needed, and the position embeddings need to adjust accordingly. And then when doing PPO updates on these steps, the position embeddings need to be computed correctly based on which timesteps (tokens) are padding. The implementation in this PR wouldn't accommodate this scenario, as it assumes the same [...]. A sketch of the position-id handling I have in mind is below.
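This is a hedged sketch only; the `decoder_position_ids` name and the shapes are illustrative, not this PR's API. The idea is that positions are counted over real tokens only, so padded timesteps don't shift the embeddings.

```python
# Sketch: derive decoder position ids from an attention mask so that padded
# decoder prompts of different lengths get correct position embeddings.
import jax.numpy as jnp

decoder_attention_mask = jnp.array(
    [[1, 1, 1, 1, 0, 0],   # 4 real tokens, 2 padding
     [1, 1, 1, 1, 1, 1]]   # 6 real tokens, no padding
)

# cumulative count of real tokens gives 1-based positions; subtract 1 and clip
# at 0 so padding positions don't go negative
decoder_position_ids = jnp.clip(jnp.cumsum(decoder_attention_mask, axis=-1) - 1, 0)

print(decoder_position_ids)
# [[0 1 2 3 3 3]
#  [0 1 2 3 4 5]]
```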
What does this PR do?
Fixes # (issue)
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.