Adapting Whisper to the new loss_function attribute #36119
Labels: Feature request, Good Second Issue
@ArthurZucker @muellerzr Following up on #35838, #34191, #34198, #34283
I would like to help bring Whisper into this. I see it was not included in the last round of fixes (#35875) related to the gradient accumulation loss bug, nor in the new global "loss_function" attr. Being an encoder-decoder model derived from Bart code in many places around loss and decoder input token handling, I suspect Bart would also benefit from such attention.
So, I would like to help with the following missing support:
It does not accept kwargs
In forward (for WhisperForConditionalGeneration), it seems straightforward to follow @muellerzr's work and implement this while keeping the tests passing. Anything special to consider there?
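To make this concrete, here is a rough toy sketch (my own stand-in, not the actual Whisper code or diff) of the pattern the other models got in the PRs above: forward accepts **kwargs and passes them through to the configurable loss, so things like num_items_in_batch coming from the Trainer actually reach the loss computation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-in (NOT the real Whisper code) for the kwargs plumbing other models got.
class ToyForConditionalGeneration(nn.Module):
    def __init__(self, hidden_size=8, vocab_size=10):
        super().__init__()
        self.vocab_size = vocab_size
        self.proj_out = nn.Linear(hidden_size, vocab_size, bias=False)

    # in transformers this would be the model-level `loss_function` attribute
    def loss_function(self, logits, labels, vocab_size, num_items_in_batch=None, **kwargs):
        reduction = "sum" if num_items_in_batch is not None else "mean"
        loss = F.cross_entropy(
            logits.view(-1, vocab_size), labels.reshape(-1), ignore_index=-100, reduction=reduction
        )
        return loss / num_items_in_batch if num_items_in_batch is not None else loss

    def forward(self, decoder_hidden_states, labels=None, **kwargs):
        lm_logits = self.proj_out(decoder_hidden_states)
        loss = None
        if labels is not None:
            # the missing piece in Whisper today: pass **kwargs through to the loss
            loss = self.loss_function(lm_logits, labels, self.vocab_size, **kwargs)
        return loss, lm_logits


model = ToyForConditionalGeneration()
hidden = torch.randn(2, 5, 8)
labels = torch.randint(0, 10, (2, 5))
loss, _ = model(hidden, labels=labels, num_items_in_batch=7)
print(loss)
```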
It does not use the global "loss_function" attr (introduced with #34191)
I find that the closest existing loss implementation would be ForMaskedLMLoss, since it seems pre-shifted labels are expected, judging from how the existing loss is calculated:
transformers/src/transformers/models/whisper/modeling_whisper.py, lines 1787 to 1791 in d4a6b40
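For reference, those lines are essentially a plain CrossEntropyLoss over the logits and the unshifted labels - paraphrased here, see the permalink above for the exact code:

```python
# WhisperForConditionalGeneration.forward, paraphrased:
loss = None
if labels is not None:
    loss_fct = CrossEntropyLoss()
    # labels are used as-is: no shifting happens inside the loss, the decoder
    # inputs are assumed to already be the right-shifted version of the labels
    labels = labels.to(lm_logits.device)
    loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
```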
Some background on the above
I find this was derived from the Bart implementation, which forces the user to either provide decoder_input_ids or have them derived from labels by shifting them to the right (a shift_tokens_right helper, sketched below), as part of Bart's denoising pre-training setup. This leads to a situation where labels are expected to arrive already shifted one position to the left of the decoder inputs - i.e. aligned one-to-one with the logits - which is exactly what the above loss calculation assumes. Whisper inherited that, but has a more involved input-id prefixing scheme: the model is hardly the place to grab the decoder start token id required to accomplish the "shift right" of labels and obtain decoder_input_ids, and anyway - for Whisper this prefix is critical during inference for determining the task, and control over it is properly reflected in other args (language, task, notimestamps, etc.).
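The shift_tokens_right helper I am referring to looks roughly like this (paraphrasing the version that lives in the Bart/Whisper modeling files). Note that it needs pad_token_id and decoder_start_token_id from the config, which is part of why the model is an awkward place for this:

```python
import torch

# Paraphrased Bart-style helper: build decoder_input_ids from labels by shifting
# right and prepending decoder_start_token_id; -100 label positions are replaced
# with pad_token_id so they can be fed to the decoder.
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
    return shifted_input_ids
```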
Thus, the proper collators suggested by @sanchit-gandhi in his great guidance and in the Distil-Whisper work explicitly specify both labels and decoder_input_ids, which works around the automatic (now unusable) "label shift-righting" (see code here). Others have "cooked" the labels to contain all but the first decoder start token id as a hack (more at #27384), and even the collator code in the popular blog post about Whisper fine-tuning does the same.
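Paraphrasing from memory (the exact blog code may differ, and here decoder_start_token_id is passed in explicitly rather than looked up on the tokenizer), the collator does something along these lines - note the last step, which strips the leading start token from the labels because the model's shift_tokens_right will prepend it again:

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import torch


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # pad the audio features and the tokenized labels separately
        input_features = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 so those positions are ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # the "cooking" step: if the decoder start token was added during tokenization,
        # cut it here because shift_tokens_right will prepend it again
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
```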
Which is, of course, a workaround to mitigate that Bart heritage.
WDYT @sanchit-gandhi, Did I get this right?
Anyway - this is why ForCausalLMLoss probably won't be a fit: it will shift the labels left to match them against the logits positions. I would like to know whether the proper loss to use is then ForMaskedLMLoss, or actually maybe a new ForConditionalGenerationLMLoss. Personally, I think a new one should exist that does exactly what ForMaskedLMLoss does, with some shared implementation for both (a minimal sketch follows below).

Also, as an aside, I would love to see the Bart-derived "decoder_input_ids from labels" logic adapted to Whisper - but I am not sure I have the experience to know how.
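To make that concrete, here is a minimal sketch of what such a loss could look like, mirroring what I understand ForMaskedLMLoss does but without any label shifting (the name and exact signature are my suggestion, not an existing API):

```python
from typing import Optional

import torch
import torch.nn.functional as F


def ForConditionalGenerationLMLoss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    vocab_size: int,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    **kwargs,
) -> torch.Tensor:
    # Unlike ForCausalLMLoss there is no left-shift of the labels here: the
    # collator (or shift_tokens_right) already offset the decoder inputs, so
    # logits[:, t] lines up with labels[:, t] directly.
    logits = logits.view(-1, vocab_size).float()
    labels = labels.reshape(-1).to(logits.device)
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction=reduction)
    if num_items_in_batch is not None:
        # normalize by the token count of the whole accumulation window (grad acc fix)
        loss = loss / num_items_in_batch
    return loss
```

If the shared part is factored the way fixed_cross_entropy is used by the other losses, both ForMaskedLMLoss and this one could simply delegate to it.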
Grad acc loss bug still applies to Whisper
As it is implemented now, you can (thankfully) customize the loss calculation using "compute_loss_func", which was introduced in #34198 - and this is mandatory for anyone who wants to avoid the grad acc loss issue described here and fixed in many PRs around the above-mentioned efforts.
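For anyone hitting this today, here is a sketch of that workaround, assuming the compute_loss_func(outputs, labels, num_items_in_batch=...) calling convention introduced in #34198 (the loss body itself is just an example, adapt as needed):

```python
import torch.nn.functional as F


def whisper_compute_loss(outputs, labels, num_items_in_batch=None):
    # outputs: the Seq2SeqLMOutput returned by the model
    # labels: label ids prepared by the data collator (already aligned with the logits)
    logits = outputs.logits
    labels = labels.to(logits.device)
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), labels.reshape(-1), ignore_index=-100, reduction=reduction
    )
    if num_items_in_batch is not None:
        # divide by the true token count across the accumulation window,
        # mirroring what the grad acc fix does in the patched models
        loss = loss / num_items_in_batch
    return loss
```

This can then be passed as compute_loss_func=whisper_compute_loss when constructing the trainer.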
This is actually an open bug for Whisper, which did not enjoy the fixed_cross_entropy injection that the other models got.

Thanks guys for all the great documentation on this - it makes it so much easier to try and contribute back!