Support generating with fallback for short form audio in Whisper #30984
Conversation
Looks like a good start @kamilakesbi. Two biggest suggestions are related to the designs of i) assisted generation, and ii) num return sequences. Think both can be simplified and assisted generation made more rigorous.
Two further design questions:
- Should we return the original `decoder_input_ids` and EOS tokens in the sequences for long-form generation as well? IMO it is an inconsistency that we return them for short-form but not long-form, and I would be in favour of unifying the two in this PR
- Is it correct to de-activate beam search when `temperature>0`? We currently don't do this for long-form generation, but given the original Whisper repo does, it would be good to determine whether this is a 'bug' or an intended design decision
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
@ArthurZucker thanks for your review! I took your remarks into account :) Failing tests are unrelated to this PR. If this is ok for you we can perhaps merge, or wait for the CI to be green...
Let's wait for the full CI, seems alright now!
Also a question not answered!
The CI is green yes :) if it's ok for you I can merge!
Thanks! Last two nits and you can merge!
What does this PR do?
The aim of this PR is to refactor the Whisper `generate` method to handle both short form and long form audio generation similarly. It will support short form audio generation with fallback (as requested in #29508). Here's what I've done:
Removed previous short-form scripts:
I've removed the part of the code used for short form generation. This involves lines 562 to 603 and lines 498 to 505 in main. Now when a short form audio (or a batch of short form audios) is passed to `generate`, it's processed by the part of the code previously used for long form generation, as sketched below.
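For illustration, a minimal usage sketch of what this enables; the checkpoint and threshold values are placeholders, not values prescribed by this PR:

```python
# Minimal sketch: with this PR, a short-form input (<= 30s) also goes through
# the fallback loop, retrying at higher temperatures whenever the quality
# thresholds are not met. Checkpoint and threshold values are illustrative.
import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

audio = np.zeros(16_000 * 5, dtype=np.float32)  # placeholder: 5 seconds of silence
inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")

generated_ids = model.generate(
    inputs.input_features,
    temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),  # a tuple of temperatures enables fallback
    logprob_threshold=-1.0,
    compression_ratio_threshold=1.35,
    no_speech_threshold=0.6,
)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
```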
Use is_shortform to still distinguish between short form and long form in some cases:
- In the `_postprocess_outputs` method we only return `past_key_values` if the audios are short form. For long form audios it is too expensive (cf. this line).
- In `_retrieve_max_frames_and_seek`: for long form audios we necessarily need to pass an attention mask, but not for short form audios. We can thus compute `max_frames` and `seek` without relying on the attention mask for short form audios (see the sketch after this list).
- I've also updated the `split_by_batch_index` method: the previous method was broken when `return_dict_in_generate` was set to True for different short form audio cases. Now it handles both short form and long form audios.
- I've removed the `is_shortform` parameter from the inputs to the `_retrieve_logit_processors` method to allow the use of `generation_config.no_speech_threshold` for short form audios.
- I've removed the `is_shortform` parameter from the inputs to the `_set_return_outputs` method to allow the use of `logprob_threshold` for short form audios.
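To illustrate the `max_frames`/`seek` point above, a simplified sketch of the idea (not the exact implementation in this PR):

```python
# Simplified sketch: for long-form inputs the attention mask tells us how many
# valid input frames each batch element has; for short-form inputs every
# element simply spans the full feature length, so no mask is required.
import torch

def retrieve_max_frames_and_seek_sketch(batch_size, attention_mask, total_input_frames, is_shortform):
    if is_shortform:
        # Short form: all samples cover the full (padded) feature window.
        max_frames = torch.full((batch_size,), total_input_frames, dtype=torch.long)
    else:
        if attention_mask is None:
            raise ValueError("Long-form batched generation requires an attention mask.")
        # Long form: number of valid frames per sample comes from the mask.
        max_frames = attention_mask.sum(-1).to(torch.long)
    seek = torch.zeros((batch_size,), dtype=torch.long)  # decoding starts at frame 0 for every sample
    return max_frames, seek
```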
Make num_return_sequences>1 compatible with generate_with_fallback:
`generate_with_fallback` can't handle `num_return_sequences>1` by design. I've added a new method, called `_expand_variables_for_generation`, which expands the different variables before passing into `generate_with_fallback` when `generation_config.num_return_sequences>1`. After expansion it will set `generation_config.num_return_sequences` to 1 for compatibility with `generate_with_fallback`, as sketched below.
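A rough sketch of that expansion idea (the method name follows the PR, the body is simplified and illustrative):

```python
# Illustrative sketch: repeat each sample num_return_sequences times along the
# batch dimension, then reset the config so that generate_with_fallback only
# ever has to produce one sequence per (expanded) batch element.
import torch

def expand_variables_for_generation_sketch(input_features: torch.Tensor, generation_config):
    n = getattr(generation_config, "num_return_sequences", 1) or 1
    if n > 1:
        # (batch, n_mels, frames) -> (batch * n, n_mels, frames)
        input_features = input_features.repeat_interleave(n, dim=0)
        generation_config.num_return_sequences = 1
    return input_features, generation_config
```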
Ensure that the output format for short form audio is compatible with the output format in main:
The output format for long-form audio is different from that for short-form audio. In order to ensure that the output is similar to that obtained in main when processing short form audio, we need to add a few post-processing steps. This is what is done in lines 721 to 765. In particular:
- We add back the `EOS` token to the output sequence, as it was removed during generation with fallback.
- We return token timestamps in the correct format when `return_token_timestamps` is True (see here).
- When `return_dict_in_generate` is True, we use the new method `_stack_split_outputs` to get the output dict (containing all attributes: scores, encoder_attentions, etc.) in the right format. `_stack_split_outputs` basically performs the opposite operations to `split_by_batch_index` (see the sketch after this list).
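As a rough illustration of that last point, a sketch of what stacking split outputs means here (assuming, for simplicity, that the per-sample tensors all share the same shape):

```python
# Illustrative only: split_by_batch_index turns a batched output dict into one
# dict per sample; stacking does the reverse, re-assembling per-sample dicts
# into a single batched dict. Assumes equal tensor shapes across samples.
import torch

def stack_split_outputs_sketch(per_sample_outputs: list[dict]) -> dict:
    stacked = {}
    for key in per_sample_outputs[0]:
        values = [out[key] for out in per_sample_outputs]
        if isinstance(values[0], torch.Tensor):
            stacked[key] = torch.stack(values, dim=0)  # re-create the batch dimension
        else:
            stacked[key] = values  # non-tensor attributes are simply collected
    return stacked
```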
Make failing slow tests pass:
Add new tests to make sure generation with fallback works for short form audios:
I've added two tests: `test_whisper_shortform_single_batch_prev_cond` and `test_whisper_shortform_multi_batch_hard_prev_cond`.
Who can review:
@sanchit-gandhi