add return_token_timestamps to WhisperProcessor #30812
Conversation
`return_num_frames` in `WhisperProcessor`
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Think this is indeed the cleanest and most reliable approach for computing `num_frames`. The alternative method we discussed offline is detailed below. Leaving it here for the next reviewer to consider, in case they believe it's a superior strategy.

Anything beyond `len(input_speech)` is padded with zeros to 30 seconds in the feature extractor. If we know what zeros correspond to in log-mel space, then we can work out how many padded frames we have in our spectrogram, and thus what the original input length was.

Note that this won't be perfect: the last frame where the audio stops is going to be affected by the end of the audio, so we'll be looking for the first frame that is entirely padding (rather than finding the frame in which the audio stops).

However, the original method by OpenAI (and the one implemented in this PR) is also imperfect: if a user took a 10-second audio and padded it by hand to 15 seconds with zeros, then `num_frames` would be computed on the length of the padded input, not the original one.
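A rough sketch of that padding-detection alternative (illustrative only, not what this PR implements; it assumes the padded region maps to a constant per-bin value in log-mel space, which we estimate by featurizing pure silence):

```python
import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")

# estimate the value that zero-padding takes in the normalised log-mel spectrogram
silence = np.zeros(16_000, dtype=np.float32)  # 1 second of silence at 16 kHz
silence_features = feature_extractor(silence, return_tensors="np").input_features[0]
pad_value = float(silence_features[0, -1])  # every frame of pure silence shares this value


def estimate_num_frames(input_features: np.ndarray) -> int:
    """Return the index of the first frame that is entirely padding.

    `input_features` is a single (num_mels, num_frames) log-mel spectrogram.
    """
    is_pad_frame = np.all(np.isclose(input_features, pad_value), axis=0)
    pad_frames = np.nonzero(is_pad_frame)[0]
    # note: the frame in which the audio actually stops is partly affected by real
    # audio, so this can over-estimate the true length by up to one frame
    return int(pad_frames[0]) if len(pad_frames) else input_features.shape[-1]
```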
@@ -474,6 +475,13 @@ def generate(
        "The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
        FutureWarning,
    )

    if input_features is not None and isinstance(input_features, BatchFeature):
Not sure why this has crept in? `input_features` should be a tensor of shape `(bsz, num_mels, num_frames)`, not a `BatchFeature` encoding. Thus, this new logic isn't required.

The correct way of using the feature extractor should be:
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset, Audio
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
dataset = dataset.cast_column("audio", Audio(16_000))
sample = next(iter(dataset))
inputs = processor(sample["audio"]["array"], return_tensors="pt")
# note here how we un-pack the batch feature encoding
pred_ids = model.generate(**inputs, language="english")
The output of the processor would be a `BatchFeature`, as indicated here, no?
Yes, but then we un-pack the `BatchFeature` when we pass it to the model, i.e. we do:
pred_ids = model.generate(**inputs)
Not:
pred_ids = model.generate(inputs)
In this case it will work with both packed and unpacked inputs. Isn't that better?
I'm aligned with @sanchit-gandhi here - handling packed and unpacked inputs isn't something any of our other processing classes handle, so it's not something we need to introduce here
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Looks good! Mostly formatting now, then we can get a final review
@@ -474,6 +475,13 @@ def generate(
        "The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
        FutureWarning,
    )

    if input_features is not None and isinstance(input_features, BatchFeature):
Yes, but then we un-pack the `BatchFeature` when we pass it to the model, i.e. we do:
pred_ids = model.generate(**inputs)
Not:
pred_ids = model.generate(inputs)
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Thanks for adding this feature and tests!
All looks good to me - just the handling of unpacked features to remove
@@ -474,6 +475,13 @@ def generate(
        "The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
        FutureWarning,
    )

    if input_features is not None and isinstance(input_features, BatchFeature):
I'm aligned with @sanchit-gandhi here - handling packed and unpacked inputs isn't something any of our other processing classes handle, so it's not something we need to introduce here
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Thanks for adding!
…b.com:kamilakesbi/transformers into timestamps_whisper_for_conditional_generation
@@ -1927,7 +1927,117 @@ def test_large_timestamp_generation(self):

        generated_ids = model.generate(input_features, max_length=448, return_timestamps=True).to("cpu")

        EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50360, 50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50257])
        EXPECTED_OUTPUT = torch.tensor(
I don't think we want to split across the lines like this. You can wrap `EXPECTED_OUTPUT` in `# fmt: off` and `# fmt: on` comments to avoid this.
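For example, a minimal sketch of that suggestion (token values truncated here for illustration):

```python
import torch

# fmt: off
EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50360, 50365, 2221, 13, 2326, 388, 391, 307, 264, 50244])
# fmt: on
```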
Ok, thanks for the tips! Will be useful ;)
You can also use `# fmt: skip` for single lines, c.f. the previous comment #30812 (comment)
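A sketch of that single-line variant (again with truncated token values):

```python
import torch

EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50360, 50365, 2221, 13, 2326, 388, 391, 307, 264, 50244])  # fmt: skip
```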
cc @amyeroberts @sanchit-gandhi Could you please merge this PR as I don't have the rights to do so?
What does this PR do?

This PR fixes #30433 by making sure we can compute timestamps with both `WhisperForConditionalGeneration` and `AutomaticSpeechRecognitionPipeline`.

We add a `return_timestamps` hyperparameter to `WhisperProcessor.feature_extractor` to be used when we want to compute timestamps. When True, the processor will return a `num_frames` parameter containing the number of frames of the input audios. `num_frames` is then passed to `generate` and used to compute timestamps.

Prior to that, timestamps were broken for whisper-large-v3 when used with `WhisperForConditionalGeneration`.
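A rough end-to-end sketch of the intended usage (the exact flag names below are assumptions based on this PR's title and description, and may not match the final API):

```python
from datasets import Audio, load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
dataset = dataset.cast_column("audio", Audio(16_000))
sample = next(iter(dataset))

# ask the processor to also return the number of frames of the (unpadded) input audio
inputs = processor(
    sample["audio"]["array"],
    return_tensors="pt",
    return_token_timestamps=True,  # flag name assumed from the PR title
)

# forward num_frames to generate so per-token timestamps can be computed
outputs = model.generate(
    inputs.input_features,
    return_token_timestamps=True,
    num_frames=inputs.num_frames,  # returned by the processor per this PR's description
)
print(outputs.token_timestamps)
```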
Who can review?
cc @sanchit-gandhi