Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct Whisper's beam search scores computation #32336

Merged
merged 2 commits into from
Sep 12, 2024

Conversation

ylacombe
Copy link
Contributor

Fixes #32246

There have been many failing tests these past days with Whisper, so I'd probably wait for them to be fixed before merging this PR.


What does this PR do?

@cifkao made a great summary the current issue in #32246:

TL;DR: Scores corresponding to the wrong sequence in the batch/beam are returned.

He also rightfully identified what was the origin of the issue:

The bug seems to be here in _postprocess_outputs. This works fine with num_beams==1, but with num_beams>1, the shape of the items in seek_outputs["scores"] will be [num_beams * batch_size, vocab_size], while the code expects it to be [batch_size, vocab_size]. Therefore, instead of choosing the correct sequence in the beam/batch, this code will incorrectly combine scores from different sequences.

The solution simply consists in taking the right logits_scores for each of the generated tokens.
Instead of taking the batch_idx-th logits_scores out of the num_beams * batch, we're now taking the beam_idx-th logits_scores.

Reproduction results

I've recomputed the code snippet from #32246.

How to read the results:

The first set of scores are the scores corresponding to each generated tokens, as well as their beam index.
The second set of scores are the scores of a handmade forward pass of the generated tokens, they indicates the "true scores" that we should have.

Notice how in #32246, the scores coming from the 1-th beam index are different from the recomputed scores. It indicates that we selected the wrong scores.

Here, they're about the same, which indicates we selected the right beam indices.

Scores out of the generation:

('<|0.00|>', -0.06171704828739166, 0)
(' Folks', -1.9032700061798096, 0)
(',', -0.40583235025405884, 0)
(' if', -0.03763910010457039, 0)
(' you', -0.0019693044014275074, 0)
(' watch', -0.14575302600860596, 0)
(' the', -0.2036631554365158, 0)
(' show', -0.002341626212000847, 0)
(',', -0.2806797921657562, 0)
(' you', -0.290231853723526, 0)
(' know', -0.025554247200489044, 0)
(' I', -1.0598242282867432, 0)
(' spent', -0.5059170722961426, 1)
(' a', -0.02328178472816944, 1)
(' lot', -0.02414931170642376, 1)
(' of', -0.02351410686969757, 1)
(' time', -0.015122056938707829, 1)
(' right', -1.1174389123916626, 1)
(' over', -0.020583242177963257, 1)
(' there', -0.031000398099422455, 0)
('.', -0.23914632201194763, 0)
('<|5.12|>', -3.7109971046447754, 0)

Scores out of the forward:

('<|en|>', -0.3857421875)
('<|transcribe|>', -6.556510925292969e-06)
('<|0.00|>', -0.1939697265625)
(' Folks', -1.931640625)
(',', -0.40966796875)
(' if', -0.0380859375)
(' you', -0.002063751220703125)
(' watch', -0.1456298828125)
(' the', -0.2041015625)
(' show', -0.00235748291015625)
(',', -0.283447265625)
(' you', -0.2978515625)
(' know', -0.0259857177734375)
(' I', -1.0849609375)
(' spent', -0.499755859375)
(' a', -0.0235137939453125)
(' lot', -0.02398681640625)
(' of', -0.0230560302734375)
(' time', -0.0152587890625)
(' right', -1.12109375)
(' over', -0.0210113525390625)
(' there', -0.030975341796875)
('.', -0.2425537109375)
('<|5.12|>', -3.802734375)
...

** Code:**

from datasets import Audio, load_dataset
from transformers import WhisperForConditionalGeneration, AutoProcessor
import torch
import numpy as np

model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny", torch_dtype=torch.float16
)
processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model.cuda()

ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
audio = ds[0]["audio"]["array"].astype(np.float32)
inputs = processor(
    [audio],
    return_tensors="pt",
    truncation=False,
    padding="longest",
    sampling_rate=16_000,
)
inputs = inputs.to(model.device, torch.float16)

generation_output = model.generate(
    **inputs,
    language="en",
    return_timestamps=True,
    return_segments=True,
    output_scores=True,
    num_beams=2,
    # num_return_sequences=1,
    temperature=0.0,
    logprob_threshold=0.0,
    compression_ratio_threshold=2.4,
    no_speech_threshold=0.6,
)

# Print each token along with its log-probability and beam index
segment = generation_output["segments"][0][0]
tokens = segment["result"]["sequences"]
scores = segment["result"]["scores"]
beam_indices = segment["result"]["beam_indices"]
logprobs = torch.as_tensor([s.float().log_softmax(-1)[t] for s, t in zip(scores, segment["tokens"])])
print(*[(processor.tokenizer.decode([t], decode_with_timestamps=True), s.item(), b.item()) for s, t, b in zip(logprobs, tokens, beam_indices)], sep="\n")


# Now run a forward pass with the generated tokens
inputs_forward = {k: v[..., :3000].cuda() for k, v in inputs.items()}
inputs_forward["decoder_input_ids"] = torch.cat(
    [
        torch.as_tensor(processor.tokenizer.encode("<|startoftranscript|><|en|><|transcribe|>", add_special_tokens=False)),
        tokens,
    ],
)[None].cuda()

with torch.inference_mode():
    output_forward = model(**inputs_forward)

# Print each token along with its log-probability
print(*[(processor.tokenizer.decode([t], decode_with_timestamps=True), s[t].item()) for s, t in zip(torch.nn.functional.log_softmax(
                output_forward.logits.squeeze(0), dim=-1
            ), inputs_forward["decoder_input_ids"].squeeze(0)[1:])], sep="\n")

cc @LysandreJik, @kamilakesbi and @sanchit-gandhi

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@cifkao
Copy link
Contributor

cifkao commented Jul 31, 2024

Nice, that looks like the correct fix to me!

I suspect that the other items (attentions, hidden states, logits) will have size num_beams * batch_size too though, so they might require indexing by beam_idx instead of batch_idx as well?

(Also, for anyone wondering why the scores are not exactly the same, it's likely because of the logits processors SuppressTokensLogitsProcessor and WhisperTimeStampLogitsProcessor, which suppress certain tokens.)

Copy link
Contributor

@sanchit-gandhi sanchit-gandhi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, thanks for fixing @ylacombe and for the speedy review @cifkao!

Would be great if you could wrap the code reproducer into a slow test so we maintain correctness going forwards :)

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we merge this and close the issue?

@ylacombe
Copy link
Contributor Author

Merging now

@ylacombe ylacombe merged commit 8f8af0f into huggingface:main Sep 12, 2024
16 checks passed
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Incorrect scores returned in Whisper with num_beams>1
5 participants