
Adapting Whisper to the new loss_function attribute #36119

Open
yoadsn opened this issue Feb 10, 2025 · 3 comments
Labels: Feature request (Request for a new feature), Good Second Issue (Issues that are more difficult to do than "Good First" issues - give it a try if you want!)

Comments


yoadsn commented Feb 10, 2025

@ArthurZucker @muellerzr Following up on #35838, #34191, #34198, #34283

I would like to help bring Whisper into this. I see it was not included in the last round of fixes in #35875 related to the loss function bug fix (grad acc.), nor in the new global "loss_function" attribute. Since it is an encoder-decoder model derived from the Bart code in many places around loss and decoder input token handling, I suspect Bart would also benefit from the same attention.

So, I would like to help with the following missing support:

It does not accept kwargs

In forward (of the conditional generation model), it seems straightforward to follow @muellerzr's work and implement this while keeping the tests passing. Anything special to consider there?

It does not use the global "loss_function" attribute (introduced in #34191)

I find that the closest loss implementation would be ForMaskedLMLoss, since it appears that already-shifted labels are expected, judging by how the existing loss is calculated:

if labels is not None:
    loss_fct = CrossEntropyLoss()
    # move labels to correct device to enable PP
    labels = labels.to(lm_logits.device)
    loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
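
For reference, my understanding of the pattern adopted by the models already migrated in #34191/#35875 is roughly the sketch below; the exact keyword arguments are my assumption, not a quote of any model's code:

# Rough sketch of the migrated pattern (my assumption of how it would look in
# Whisper's forward; not actual transformers code).
loss = None
if labels is not None:
    loss = self.loss_function(
        logits=lm_logits,
        labels=labels,
        vocab_size=self.config.vocab_size,
        **kwargs,  # this is why forward needs to accept loss kwargs first
    )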

Some background on the above

I find this was derived from the Bart implementation, which forces the user to either provide decoder_input_ids or have them derived from the labels by shifting them to the right (as part of the denoising pre-training task). This leads to a situation where the labels end up shifted one position to the left of the decoder inputs, i.e. aligned with the logits, which is exactly what the above loss calculation assumes (it compares logits and labels position by position with no shift).
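
For readers less familiar with that heritage, the shift-right helper does roughly the following (paraphrased from memory, so treat the details as an approximation rather than a quote of the Bart/Whisper source):

import torch

# Approximate paraphrase of the Bart/Whisper `shift_tokens_right` helper.
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    # labels use -100 for ignored positions; the decoder must see pad ids instead
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted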

Whisper inherited that, but has a more involved input id prefixing scheme. The model is hardly the right place to grab the "decoder start token id" required to accomplish the "shift right" of the labels and obtain the decoder_input_ids, and in any case, for Whisper this prefix is critical at inference time for determining the task; control over it is properly exposed through other arguments (language, task, notimestamps, etc.).

Thus, the proper collators suggested by @sanchit-gandhi in his great guidance, as well as the Distil-Whisper work, explicitly specify both labels and decoder_input_ids, which works around the automatic (now unusable) "shift labels right" behavior. (See code here)

Otherwise, the labels are "cooked" to contain everything but the leading decoder start token id as a hack (more at #27384); even the collator code in the popular blog post about Whisper fine-tuning does:

# if bos token is appended in previous tokenization step,
# cut bos token here as it's appended later anyways
if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
    labels = labels[:, 1:]

This is, of course, a workaround to mitigate that Bart heritage.
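
For concreteness, the "explicit decoder_input_ids" approach mentioned above boils down to something like the sketch below (a simplified illustration with assumed tokenizer usage, not the actual Distil-Whisper collator; it reuses the shift_tokens_right paraphrase from earlier):

# Simplified sketch of a collator that provides both tensors explicitly
# (illustrative only; assumes the tokenized labels do not start with the
# decoder start token).
def collate_labels(features, tokenizer, decoder_start_token_id):
    label_features = [{"input_ids": f["labels"]} for f in features]
    batch = tokenizer.pad(label_features, return_tensors="pt")
    # mask padding positions so they are ignored by the loss
    labels = batch["input_ids"].masked_fill(batch["attention_mask"].ne(1), -100)
    # build the decoder inputs ourselves instead of relying on the model's shift
    decoder_input_ids = shift_tokens_right(labels, tokenizer.pad_token_id, decoder_start_token_id)
    return {"labels": labels, "decoder_input_ids": decoder_input_ids}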

WDYT @sanchit-gandhi, did I get this right?

Anyway, this is why ForCausalLMLoss probably won't be a fit: it shifts the labels left to match them against the logits positions.

I would like to know whether the proper loss to use is ForMaskedLMLoss, or maybe a new ForConditionalGenerationLMLoss. Personally, I think a new one should exist that does exactly what ForMaskedLMLoss does, with a shared implementation for both.
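
To make that proposal concrete, such a loss would essentially be a no-shift cross entropy; here is a sketch of the idea (a hypothetical function, not an existing transformers loss; the num_items_in_batch handling mirrors my understanding of the grad-acc fix):

import torch.nn.functional as F

# Hypothetical ForConditionalGenerationLMLoss: identical in spirit to
# ForMaskedLMLoss, i.e. no label shifting before the cross entropy.
def ForConditionalGenerationLMLoss(logits, labels, vocab_size, num_items_in_batch=None, ignore_index=-100, **kwargs):
    labels = labels.to(logits.device)
    logits = logits.view(-1, vocab_size).float()
    labels = labels.view(-1)
    # sum + divide by the global token count to stay correct under grad accumulation
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction=reduction)
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch
    return loss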

Also, as an aside, I would love to see the Bart-derived "decoder_input_ids from labels" logic adapted to Whisper, but I am not sure I have the experience to know how.

The grad acc loss bug still applies to Whisper

As it is implemented now, you can (thankfully) customize the loss calculation using "compute_loss_func", which was introduced in #34198 - and this is mandatory for anyone who wants to avoid the gradient accumulation loss bug described here and fixed in many PRs around the above-mentioned efforts.
This is actually an open bug for Whisper, which did not get the common fixed_cross_entropy injection that other models received.
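
Until that lands, my understanding is that the workaround looks roughly like the sketch below; the compute_loss_func signature is my reading of #34198, so treat it as an assumption:

import torch.nn.functional as F

# Rough sketch of working around the grad-acc loss bug today with a custom
# compute_loss_func (signature assumed from my reading of #34198).
def whisper_compute_loss(outputs, labels, num_items_in_batch=None):
    logits = outputs.logits
    labels = labels.to(logits.device)
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = F.cross_entropy(
        logits.view(-1, logits.shape[-1]),
        labels.view(-1),
        ignore_index=-100,
        reduction=reduction,
    )
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch
    return loss

# e.g. Seq2SeqTrainer(..., compute_loss_func=whisper_compute_loss)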

Thanks, guys, for all the great documentation on this - it makes it so much easier to try and contribute back!

@Rocketknight1 (Member)

cc @eustlb

@ArthurZucker added the Feature request and Good Second Issue labels on Feb 19, 2025
@ArthurZucker (Collaborator)

Super welcome to anyone who wishes to fix it! 🤗

yoadsn (Author) commented Feb 19, 2025

@ArthurZucker I would like to work on this!

I do need some guidance though: is my analysis of the "labels" -> "decoder_input_ids" problem correct?
Should I create a new loss type, or try to force Whisper to play nicely with ForCausalLMLoss? The latter would require many developers to adapt their Whisper data collators, though, and I know we do not wish to break legacy support.

That is basically the main issue here: ForCausalLMLoss expects labels that still need to be shifted (it shifts them internally), whereas Whisper expects labels that are already shifted.

Thanks!
