[Whisper] Make tests faster #24105
Conversation
```python
pt_model.config.use_cache = False

# load Flax class
fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32)
```
We have to override this method to ensure that we init the Flax weights with the downsampled sequence length correctly (e.g. pass `input_shape=init_shape`).
I am a bit confused here. If you look at `FlaxWhisperModelTest`, there is no such override to pass `input_shape`. However, `FlaxWhisperModelTester` uses the low number as in this PR. Why don't we need to pass `init_shape` in `FlaxWhisperModelTest`?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Sorry, I am asking not because I see a test being slow, but because I saw some more Whisper test failures on the daily CI. But yes, in general, it's best to use low numbers. I will take a look.
Note that the Whisper tests have already been flagged as being slow (#23736), so this should help combat that issue!
It's not because a test is slow that we should use large values without a really valid reason :-). The goal is always to make tests use low values, unless it's absolutely necessary. I still have questions on why we don't need to pass `init_shape` though.
OK, in the flax test file, I see ... probably that's the reason.
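Presumably the mechanism being pointed at is that the Flax model derives a default `input_shape` from the config when none is passed at init, so a tester that builds a small config gets a matching small init shape for free. A sketch of that pattern, stated as an assumption rather than a quote from the modeling code:

```python
import jax.numpy as jnp


class FlaxWhisperStyleModel:
    """Illustrative stand-in for the Flax pretrained-model init pattern."""

    def __init__(self, config, input_shape=None, dtype=jnp.float32):
        if input_shape is None:
            # the conv frontend halves the frame count, so the raw input
            # length is twice the encoder context window
            input_shape = (1, config.num_mel_bins, 2 * config.max_source_positions)
        self.input_shape = input_shape
        self.dtype = dtype
```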
Thank you @sanchit-gandhi, LGTM. I convinced myself regarding the `input_shape` thing, but any comment is welcome.
Yep agreed - the seq len was unnecessarily high here :) You're spot on regarding the init shape: we have to change this based on the sequence length, since Flax Whisper initialises the positional embeddings based on the context window. So if we change the seq len (= context window), we need to init the weights with the new shape.
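To make that coupling concrete, here is a minimal sketch (illustrative assumptions, not the PR's test code: the seq length of 60 and batch size of 1 are just example values) of initialising Flax Whisper with a reduced context window:

```python
import jax.numpy as jnp
from transformers import WhisperConfig, FlaxWhisperModel

seq_len = 60  # reduced input length in mel frames (default is 1500)

config = WhisperConfig(
    # Whisper's conv frontend downsamples the features by a factor of 2,
    # so the encoder context window (and therefore the size of the
    # positional embedding table) is seq_len // 2
    max_source_positions=seq_len // 2,
)

# input_shape is (batch, num_mel_bins, num_frames); passing it at init
# time makes the randomly initialised positional embeddings match the
# short context window instead of the full-length default
init_shape = (1, config.num_mel_bins, seq_len)
model = FlaxWhisperModel(config, input_shape=init_shape, dtype=jnp.float32)
```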
Thanks!
What does this PR do?
Reduces the input seq length of the Whisper tests from 1500 -> 60 frames. This in turn should speed up the tests quite considerably.
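As a rough illustration of the kind of change this implies (hypothetical tester attributes following the usual transformers ModelTester pattern, not copied from the diff):

```python
class WhisperModelTester:
    """Hypothetical sketch of the seq-length reduction described above."""

    def __init__(self, batch_size=2, seq_length=60):  # previously 1500
        self.batch_size = batch_size
        self.seq_length = seq_length
        # the encoder context window must track the downsampled input
        # length, since the conv frontend halves the number of frames
        self.max_source_positions = seq_length // 2
```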