add return_token_timestamps to WhisperProcessor #30812

Merged
Changes from 6 commits
5 changes: 5 additions & 0 deletions src/transformers/models/whisper/feature_extraction_whisper.py
@@ -188,6 +188,7 @@ def __call__(
         sampling_rate: Optional[int] = None,
         do_normalize: Optional[bool] = None,
         device: Optional[str] = "cpu",
+        return_timestamps: Optional[int] = None,
         **kwargs,
     ) -> BatchFeature:
         """
@@ -302,6 +303,7 @@ def __call__(

         if isinstance(input_features[0], List):
             padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
+
         else:
             padded_inputs["input_features"] = input_features

@@ -312,4 +314,7 @@ def __call__(
         if return_tensors is not None:
             padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
 
+        if return_timestamps is not None:
+            padded_inputs["num_frames"] = [len(raw_speech[i]) // self.hop_length for i in range(len(raw_speech))]
+
         return padded_inputs
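
For reference, a minimal sketch of how the new argument could be exercised (hedged: this assumes a transformers build that includes this branch; the checkpoint and the dummy audio are illustrative placeholders, not part of the PR):

import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")

# two seconds of silence at 16 kHz stands in for real speech
raw_speech = np.zeros(32_000, dtype=np.float32)

inputs = feature_extractor(
    raw_speech,
    sampling_rate=16_000,
    return_tensors="pt",
    return_timestamps=True,  # new in this PR
)

# num_frames = number of audio samples // hop_length (hop_length is 160 for Whisper)
print(inputs["num_frames"])  # [200]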
8 changes: 8 additions & 0 deletions src/transformers/models/whisper/generation_whisper.py
@@ -23,6 +23,7 @@
 import torch.nn.functional as F
 from torch import nn
 
+from ...feature_extraction_utils import BatchFeature
 from ...generation.configuration_utils import GenerationConfig
 from ...generation.logits_process import (
     LogitsProcessorList,
@@ -474,6 +475,13 @@ def generate(
"The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
FutureWarning,
)

if input_features is not None and isinstance(input_features, BatchFeature):
Review thread on this line:

Contributor (@sanchit-gandhi) commented:

Not sure why this has crept in? `input_features` should be a tensor of shape (bsz, num_mels, num_frames), not a BatchFeature encoding. Thus, this new logic isn't required.

The correct way of using the feature extractor should be:

from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset, Audio

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
dataset = dataset.cast_column("audio", Audio(16_000))

sample = next(iter(dataset))
inputs = processor(sample["audio"]["array"], return_tensors="pt")

# note here how we un-pack the batch feature encoding
pred_ids = model.generate(**inputs, language="english")

Contributor Author (@kamilakesbi) replied on May 15, 2024:

The output of the processor would be a BatchFeature, as indicated here, no?

Contributor (@sanchit-gandhi) replied:

Yes, but then we un-pack the BatchFeature when we pass it to the model, i.e. we do:

pred_ids = model.generate(**inputs)

Not:

pred_ids = model.generate(inputs)

Contributor Author (@kamilakesbi) replied:

In this case it will work with both packed and unpacked inputs. Isn't that better?

Collaborator replied:

I'm aligned with @sanchit-gandhi here - handling packed and unpacked inputs isn't something any of our other processing classes handle, so it's not something we need to introduce here.
if "num_frames" in input_features.keys():
kwargs["num_frames"] = input_features.pop("num_frames")
if "input_features" in input_features.keys():
input_features = input_features.input_features

         # 1. prepare generation config
         generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)
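
To ground the thread above, here is a small sketch contrasting the two call styles (hedged: illustrative only, with a dummy waveform; the packed form works only with the new BatchFeature branch added in this PR):

import numpy as np
from transformers import WhisperProcessor, WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

# dummy one-second clip; any 16 kHz waveform works here
audio = np.zeros(16_000, dtype=np.float32)
inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")

# unpacked: the BatchFeature is expanded into keyword arguments,
# the convention used by the rest of the library
pred_ids = model.generate(**inputs)

# packed: the whole BatchFeature passed positionally; only the new
# isinstance(input_features, BatchFeature) branch above accepts this
pred_ids = model.generate(inputs)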
28 changes: 16 additions & 12 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -443,11 +443,18 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
                 return_tensors="pt",
             )
         else:
-            processed = self.feature_extractor(
-                inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
-            )
-            if stride is None:
-                extra["segment_size"] = len(inputs)
+            if self.type == "seq2seq_whisper" and stride is None:
+                processed = self.feature_extractor(
+                    inputs,
+                    sampling_rate=self.feature_extractor.sampling_rate,
+                    return_tensors="pt",
+                    return_timestamps=True,
+                )
+                extra["num_frames"] = processed.pop("num_frames")
+            else:
+                processed = self.feature_extractor(
+                    inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
+                )
 
         if self.torch_dtype is not None:
             processed = processed.to(dtype=self.torch_dtype)
@@ -461,11 +468,11 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
     def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
         attention_mask = model_inputs.pop("attention_mask", None)
         stride = model_inputs.pop("stride", None)
-        segment_size = model_inputs.pop("segment_size", None)
+        num_frames = model_inputs.pop("num_frames", None)
         is_last = model_inputs.pop("is_last")
 
-        if stride is not None and segment_size is not None:
-            raise ValueError("segment_size must be used only when stride is None")
+        if stride is not None and num_frames is not None:
+            raise ValueError("num_frames must be used only when stride is None")
 
         if self.type in {"seq2seq", "seq2seq_whisper"}:
             encoder = self.model.get_encoder()
@@ -495,10 +502,7 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride]

else:
if isinstance(segment_size, int):
generate_kwargs["num_frames"] = segment_size // self.feature_extractor.hop_length
else:
generate_kwargs["num_frames"] = segment_size[0] // self.feature_extractor.hop_length
generate_kwargs["num_frames"] = num_frames

if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames:
generate_kwargs["input_features"] = inputs
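
For reference, a sketch of the pipeline path these changes affect: with a short input (no chunk_length_s, so stride is None), preprocessing now carries num_frames through to generate() so Whisper can compute word-level timestamps (hedged: illustrative usage assuming this branch; the checkpoint is a placeholder):

from transformers import pipeline
from datasets import load_dataset, Audio

asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
dataset = dataset.cast_column("audio", Audio(16_000))
sample = next(iter(dataset))

# no chunk_length_s is passed, so stride is None and the new
# return_timestamps=True branch of preprocess() is taken
out = asr(sample["audio"]["array"], return_timestamps="word")
print(out["chunks"])  # [{'text': ..., 'timestamp': (start, end)}, ...]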