Missing timestamp offset using Whisper with pipeline and sequential decoding #34210

Open
dintifla opened this issue Oct 17, 2024 · 11 comments · May be fixed by #35750

dintifla commented Oct 17, 2024

System Info

  • transformers version: 4.45.2
  • Platform: macOS-15.0.1-arm64-arm-64bit
  • Python version: 3.12.1
  • Huggingface_hub version: 0.23.3
  • Safetensors version: 0.4.3
  • Accelerate version: 0.34.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.1 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: no

Who can help?

@Rocketknight1 @gante @ylacombe

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. pip install transformers==4.45.2

  2. Set up a Whisper pipeline with chunk_length_s=0 (which, according to the model card, at least for large-v3, selects sequential long-form decoding) and return_timestamps=True

  3. Transcribe an audio file longer than 30 s

    from transformers import pipeline
    import torch
    
    audio_file = '<an-audio-file-longer-than-30-s>'
    # True: chunked long-form algorithm; False: sequential long-form decoding
    chunked = False
    
    pipe = pipeline(
        'automatic-speech-recognition',
        model='openai/whisper-small',
        chunk_length_s=30 if chunked else 0,  # 0 selects sequential decoding
        return_timestamps=True,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu',
    )
    
    result = pipe(audio_file)
    # One line per segment: "(start, end)<TAB>text"
    transcript = '\n'.join(
        f"({chunk['timestamp'][0]}, {chunk['timestamp'][1]})\t{chunk['text']}" for chunk in result['chunks']
    )
    print(transcript)
  4. Observe that the timestamps restart at 0.0 s after the first 30 s

    (0.0, 4.44)      Er hatte schon mal eine Schnauze voll von allem und jedem.
    (4.44, 6.28)     Und er hat den Schluss getroffen.
    (6.28, 7.8)      Es hilft nichts mehr.
    (7.8, 9.28)      Ich wandere aus.
    (9.28, 11.4)     Das kann ein Grund sein,
    (11.4, 14.48)    wieso er eine Heimat für immer der Rückenträger will.
    (14.48, 16.72)   Oder es ist etwas ganz anderes.
    (16.72, 19.24)   Der wohl bekannt ist Grund...
    (19.24, 20.36)  ... die Liebe.
    (20.36, 22.44)   So ist es bei Hans Muster.
    (22.44, 24.72)   Die Liebe hat ihn nach Deutschland gezogen.
    (24.72, 27.0)    Und dort ist er seit vier Jahren.
    (27.0, 29.4)     Aber welter der für immer dort bleibt.
    (0.0, 1.0)       Gute Frage.
    (1.0, 4.0)       Ich stelle mir einen Gart am Viertel vor im PO bei den Leuten.
    (4.0, 7.0)       Und bis dort her, mein Name ist Peter Müller.
    (7.0, 11.0)      Und ich bin Wassermelone Heines vom Harry Styles.
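
Until a fix lands, one possible stopgap is to post-process result['chunks'] and re-add the missing offsets. This is a minimal sketch, not part of transformers: add_missing_offsets is a hypothetical helper, and it assumes that every time the timestamps restart, a new decoding window began a full 30 s (window_s) after the previous one. That assumption may not hold for every window (for example, when decoding resumes from the last timestamp of an unfinished segment), so the shifted values can be approximate.

```python
from typing import Any, Dict, List


def add_missing_offsets(chunks: List[Dict[str, Any]],
                        window_s: float = 30.0) -> List[Dict[str, Any]]:
    """Hypothetical workaround, not part of transformers.

    Whenever a chunk starts before the previous chunk ended (both relative to
    their window), assume sequential decoding moved on to a new window of
    `window_s` seconds and shift this and all later chunks accordingly.
    """
    fixed = []
    offset = 0.0
    prev_end = None  # previous chunk's end, relative to its own window
    for chunk in chunks:
        start, end = chunk["timestamp"]
        if prev_end is not None and start < prev_end:
            offset += window_s  # timestamps restarted: assume a new window
        new_end = None if end is None else end + offset
        # Copy the chunk so the pipeline result is not mutated.
        fixed.append({**chunk, "timestamp": (start + offset, new_end)})
        # The last chunk of a pipeline result may have no end timestamp.
        prev_end = start if end is None else end
    return fixed
```

Usage: result['chunks'] = add_missing_offsets(result['chunks']).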
    

Expected behavior

The timestamps should remain correct when the audio is longer than 30 s, as they are when the chunked algorithm is used:

(0.0, 4.44)      Er hatte schon mal eine Schnauze voll von allem und jedem.
(4.44, 6.28)     Und er hat den Schluss getroffen.
(6.28, 7.8)      Es hilft nichts mehr.
(7.8, 9.28)      Ich wandere aus.
(9.28, 11.4)     Das kann ein Grund sein,
(11.4, 14.48)    wieso er eine Heimat für immer der Rückenträger will.
(14.48, 16.72)   Oder es ist etwas ganz anderes.
(16.72, 19.24)   Der wohl bekannt ist Grund...
(19.24, 20.36)  ... die Liebe.
(20.36, 22.44)   So ist es bei Hans Muster.
(22.44, 24.72)   Die Liebe hat ihn nach Deutschland gezogen.
(24.72, 26.0)    Und dort ist er seit vier Jahren.
(26.0, 29.0)     Aber welter der für immer dort bleibt, gute Frage.
(29.0, 32.0)     Wir stellen es dir an, am Viertel vor, im PO bei den Leuten.
(32.0, 35.0)     Und bis dort her, mein Name ist Peter Müller.
(35.0, 39.0)     Und jetzt ein Wassermelon Heines vom Harry Styles.

This output was produced by the script above with chunked=True.
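
The difference between the observed and the expected output can be checked mechanically: in the expected output, no segment starts before the previous one ended. A small sketch that could back a regression test (first_backward_jump is a hypothetical helper, not a transformers API):

```python
from typing import Iterable, Optional, Tuple


def first_backward_jump(timestamps: Iterable[Tuple[float, Optional[float]]]) -> Optional[int]:
    """Index of the first segment that starts before the previous one ended.

    Returns None if the timestamps never move backwards, as in the expected output.
    """
    prev_end: Optional[float] = None
    for i, (start, end) in enumerate(timestamps):
        if prev_end is not None and start < prev_end:
            return i
        # The last chunk of a pipeline result may have no end timestamp.
        prev_end = start if end is None else end
    return None
```

Usage: first_backward_jump(c['timestamp'] for c in result['chunks']) returns the index of the first chunk that restarted at 0.0 in the buggy output, and None once the offsets are correct.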

dintifla added the bug label on Oct 17, 2024
Member

gante commented Oct 17, 2024

cc @eustlb, since you're working on other Whisper fixes :)

dintifla changed the title from "Timestamps using Whisper with pipeline and sequential decoding start at 0.0 after 30s" to "Missing timestamp offset using Whisper with pipeline and sequential decoding" on Oct 17, 2024
dineshveguru commented

Any fix?

Rocketknight1 commented
Member

@ylacombe, can you see a quick fix here? If it's more of a pipeline issue and you don't have any ideas, let me know and I can take it.

ylacombe commented
Contributor

Hey @dintifla and @dineshveguru, thanks for your message.

cc @eustlb, this seems linked to #34537, although it's not exactly the same issue. Any idea why it happens?

Author

dintifla commented Nov 26, 2024

Hey @ylacombe

As far as I can tell, it happens because stride is missing from the output here:

Hence, time_offset is always 0.0.

The calculation in your linked PR works differently.
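
A heavily simplified sketch of the bookkeeping described above (not the actual transformers decoding code; the function name and the per-output dict layout are assumptions for illustration): the running offset only advances when an output carries a stride, so without strides every window is reported relative to 0.0.

```python
from typing import Dict, List, Optional, Tuple

Stride = Tuple[float, float, float]  # (chunk_len, stride_left, stride_right), seconds


def window_offsets(outputs: List[Dict[str, Optional[Stride]]]) -> List[float]:
    """Offset added to each window's relative timestamps (simplified).

    The running offset is only updated from an output's "stride" entry, so
    outputs without a stride are all reported relative to 0.0.
    """
    offsets: List[float] = []
    time_offset = 0.0
    for output in outputs:
        stride = output.get("stride")
        if stride is not None:
            chunk_len, stride_left, stride_right = stride
            time_offset -= stride_left  # left stride overlaps the previous window
            offsets.append(time_offset)
            time_offset += chunk_len - stride_right
        else:
            offsets.append(time_offset)  # no stride: the offset never advances
    return offsets
```

With three 30 s chunks and 5 s strides this gives offsets of 0, 20 and 40 s; with stride-less outputs it gives 0.0 for every window, which matches the symptom in this issue.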

ylacombe commented
Contributor

Indeed, @dintifla, you're right about the cause. I'll dig into this tomorrow.

Contributor

ylacombe commented Nov 27, 2024

It turns out it's not because stride is missing: stride is only present when chunk_length_s > 0. In other words, it's only used by the ASR pipeline's general (chunked) long-form transcription algorithm.

The error you observed actually happens because time_offset is always zero here if chunk_length_s=0.
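
For comparison, this is roughly how the chunked path produces strides when chunk_length_s > 0: the audio is cut into overlapping windows, and each window carries (chunk_len, stride_left, stride_right). The sketch below is a simplified illustration that works in seconds rather than samples; chunk_windows is a hypothetical helper, and the 5 s stride for 30 s chunks assumes the pipeline default of stride_length_s = chunk_length_s / 6.

```python
from typing import List, Tuple

Stride = Tuple[float, float, float]  # (chunk_len, stride_left, stride_right), seconds


def chunk_windows(duration: float, chunk_len: float = 30.0,
                  stride: float = 5.0) -> List[Tuple[float, Stride]]:
    """Return (window_start, stride) for each window of a chunked transcription.

    Simplified illustration: consecutive windows overlap by `stride` seconds
    on each side; the first window has no left stride and the last window
    has no right stride.
    """
    step = chunk_len - 2 * stride
    if step <= 0:
        raise ValueError("stride must be smaller than half of chunk_len")
    windows: List[Tuple[float, Stride]] = []
    start = 0.0
    while start < duration:
        is_last = start + chunk_len >= duration
        left = 0.0 if start == 0.0 else stride
        right = 0.0 if is_last else stride
        windows.append((start, (chunk_len, left, right)))
        if is_last:
            break
        start += step
    return windows
```

For a 39 s file this yields two windows starting at 0 s and 20 s. With chunk_length_s=0 no such windows (and hence no strides) are produced.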

To fix this, we have to be careful about how we set the time offset. @eustlb, you've worked a lot on this in #34537. Could you give a TL;DR of the different cases we can observe? Notably, it seems that we're facing the following case:

<0> t1 t2 <T1> <T1> t3 t4 t5 ... tn <TN> <0> tn+1 tn+2 <T1_bis> <T1_bis>...

where tX denotes text tokens and <TX> denotes timestamp tokens. Are there other possible cases besides <TN> <0>? Does <TN> <0> fall into the first or the second case described in #34537?
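
To make the <TN> <0> case concrete, here is a minimal sketch (segments_from_tokens is a hypothetical helper, not the transformers tokenizer) that turns such a sequence into absolute segments. Text tokens are modelled as strings and timestamp tokens as floats in seconds relative to their window; a start timestamp lower than the previous end timestamp is treated as the start of a new window. It assumes each new window begins a full 30 s after the previous one, which is what should happen when a window ends on a single timestamp token; windows that end mid-segment (where decoding resumes from the last timestamp) are not handled.

```python
from typing import List, Optional, Tuple, Union

Token = Union[str, float]  # text token, or timestamp in seconds relative to its window


def segments_from_tokens(tokens: List[Token],
                         window_s: float = 30.0) -> List[Tuple[float, float, str]]:
    """Turn '<0> t1 t2 <T1> <T1> ... tn <TN> <0> ...' into absolute segments.

    Hypothetical helper: a start timestamp lower than the previous end
    timestamp is treated as a new window that begins `window_s` seconds
    after the previous one.
    """
    segments: List[Tuple[float, float, str]] = []
    offset = 0.0
    start: Optional[float] = None     # absolute start of the open segment
    last_end: Optional[float] = None  # relative end of the last closed segment
    text: List[str] = []
    for tok in tokens:
        if isinstance(tok, str):
            text.append(tok)
        elif start is None:
            # Timestamp that opens a segment.
            if last_end is not None and tok < last_end:
                offset += window_s  # "<TN> <0>": a new window started
            start = tok + offset
        else:
            # Timestamp that closes the open segment.
            segments.append((start, tok + offset, " ".join(text)))
            last_end, start, text = tok, None, []
    return segments
```

Matching pairs such as <T1> <T1> are read as "end of one segment, start of the next"; only a start that goes backwards triggers the offset.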


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

Author

dintifla commented Jan 7, 2025

This issue still persists (v4.47.1). Please reopen it.
Can I assist somehow, @ylacombe?

Contributor

eustlb commented Jan 17, 2025

Fixed in #35750, which will be merged ASAP! Thanks a lot for raising this issue, and thanks a lot for your patience 🤗


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
