
add return_token_timestamps to WhisperProcessor #30812

Merged
8 changes: 8 additions & 0 deletions src/transformers/models/whisper/feature_extraction_whisper.py
@@ -188,6 +188,7 @@ def __call__(
sampling_rate: Optional[int] = None,
do_normalize: Optional[bool] = None,
device: Optional[str] = "cpu",
return_token_timestamps: Optional[bool] = None,
**kwargs,
) -> BatchFeature:
"""
@@ -237,6 +238,9 @@ def __call__(
device (`str`, *optional*, defaults to `'cpu'`):
Specifies the device for computation of the log-mel spectrogram of audio signals in the
`_torch_extract_fbank_features` method. (e.g., "cpu", "cuda")
return_token_timestamps (`bool`, *optional*, defaults to `None`):
Whether or not to return the number of frames of the input raw_speech.
This `num_frames` value can be used by the model to compute word-level timestamps.
"""

if sampling_rate is not None:
@@ -302,13 +306,17 @@ def __call__(

if isinstance(input_features[0], List):
padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]

else:
padded_inputs["input_features"] = input_features

if return_attention_mask:
# rescale from sample (48000) to feature (3000)
padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]

if return_token_timestamps is not None:
padded_inputs["num_frames"] = [len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech]

if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

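The feature-extractor change above is self-contained: when `return_token_timestamps` is set, the returned batch gains a `num_frames` entry computed from the raw audio length and the hop length. A minimal usage sketch, assuming 16 kHz mono input and Whisper's standard hop length of 160 samples (the checkpoint name is only illustrative):

```python
import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")

# One second of dummy audio at 16 kHz.
raw_speech = np.zeros(16_000, dtype=np.float32)

inputs = feature_extractor(
    raw_speech,
    sampling_rate=16_000,
    return_tensors="pt",
    return_token_timestamps=True,
)

# num_frames = len(raw_speech) // hop_length = 16_000 // 160 = 100
print(inputs["num_frames"])
```

Because `num_frames` is derived from the unpadded waveform, it reflects the true audio length rather than the padded 30-second window, which is what the timestamp extraction below relies on.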
12 changes: 8 additions & 4 deletions src/transformers/models/whisper/generation_whisper.py
@@ -207,13 +207,16 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec
# two cases:
# 1. num_frames is the same for each sample -> compute the DTW matrix for each sample in parallel
# 2. num_frames differs across samples -> compute the DTW matrix for each sample sequentially
if isinstance(num_frames, torch.Tensor):
num_frames = num_frames.to("cpu")

# we're using np.unique because num_frames can be int/list/tuple
if len(np.unique(num_frames)) == 1:
# if num_frames is the same, no need to recompute matrix, std and mean for each element of the batch
num_frames = num_frames if isinstance(num_frames, int) else num_frames[0]

weights = weights[..., : num_frames // 2]
if isinstance(num_frames, int):
weights = weights[..., : num_frames // 2]
else:
weights = weights[..., : num_frames[0] // 2]
else:
# num_frames is of shape (batch_size,) whereas batch_size is actually batch_size*num_return_sequences
repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames)
@@ -231,7 +234,7 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec

# Perform dynamic time warping on each element of the batch.
for batch_idx in range(batch_size):
if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray)):
if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray, torch.Tensor)):
matrix = weights[batch_idx, ..., : num_frames[batch_idx] // 2]

# Normalize and smoothen the weights.
@@ -474,6 +477,7 @@ def generate(
"The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
FutureWarning,
)

# 1. prepare generation config
generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)

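For context on the `// 2` slicing above: Whisper's encoder downsamples the mel spectrogram by a factor of two (its second convolution has stride 2), so `num_frames` mel frames map to `num_frames // 2` cross-attention positions. Below is a standalone sketch of the two-case batching logic, with illustrative names rather than the library's internal API:

```python
import numpy as np
import torch

def trim_attention_weights(weights: torch.Tensor, num_frames):
    """Trim cross-attention weights to the unpadded audio length.

    Assumes `weights` has shape (batch, heads, tokens, audio_positions) and
    `num_frames` is an int, list/tuple, np.ndarray, or torch.Tensor.
    """
    if isinstance(num_frames, torch.Tensor):
        num_frames = num_frames.to("cpu")

    if len(np.unique(num_frames)) == 1:
        # Case 1: every sample shares one length -> slice the whole batch at once.
        shared = num_frames if isinstance(num_frames, int) else int(num_frames[0])
        return weights[..., : shared // 2]

    # Case 2: lengths differ -> keep a per-sample slice for sequential DTW.
    return [weights[i, ..., : int(num_frames[i]) // 2] for i in range(weights.shape[0])]
```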
28 changes: 16 additions & 12 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -443,11 +443,18 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
return_tensors="pt",
)
else:
processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
if stride is None:
extra["segment_size"] = len(inputs)
if self.type == "seq2seq_whisper" and stride is None:
processed = self.feature_extractor(
inputs,
sampling_rate=self.feature_extractor.sampling_rate,
return_tensors="pt",
return_token_timestamps=True,
)
extra["num_frames"] = processed.pop("num_frames")
else:
processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)

if self.torch_dtype is not None:
processed = processed.to(dtype=self.torch_dtype)
@@ -461,11 +468,11 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
attention_mask = model_inputs.pop("attention_mask", None)
stride = model_inputs.pop("stride", None)
segment_size = model_inputs.pop("segment_size", None)
num_frames = model_inputs.pop("num_frames", None)
is_last = model_inputs.pop("is_last")

if stride is not None and segment_size is not None:
raise ValueError("segment_size must be used only when stride is None")
if stride is not None and num_frames is not None:
raise ValueError("num_frames must be used only when stride is None")

if self.type in {"seq2seq", "seq2seq_whisper"}:
encoder = self.model.get_encoder()
@@ -495,10 +502,7 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride]

else:
if isinstance(segment_size, int):
generate_kwargs["num_frames"] = segment_size // self.feature_extractor.hop_length
else:
generate_kwargs["num_frames"] = segment_size[0] // self.feature_extractor.hop_length
generate_kwargs["num_frames"] = num_frames

if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames:
generate_kwargs["input_features"] = inputs
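At the pipeline level, the practical effect is that short-form (non-chunked) Whisper inputs now pass their true `num_frames` through to `generate`, so word-level timestamps are clipped to the real audio length instead of the padded 30-second window. A hedged end-to-end sketch, where `audio.wav` stands in for any local file shorter than 30 seconds:

```python
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

# With this change, the last word's timestamp ends at the real end of the
# audio rather than stretching toward the 30-second padding boundary.
result = asr("audio.wav", return_timestamps="word")
print(result["chunks"])
```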
91 changes: 91 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py
@@ -1912,6 +1912,67 @@ def test_tiny_timestamp_generation(self):
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

@slow
def test_large_timestamp_generation(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
model.to(torch_device)

input_speech = np.concatenate(self._load_datasamples(4))
input_features = processor(
input_speech, return_tensors="pt", sampling_rate=16_000, return_token_timestamps=True
).input_features
input_features = input_features.to(torch_device)

generated_ids = model.generate(input_features, max_length=448, return_timestamps=True).to("cpu")

EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50360, 50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50257])
self.assertTrue(torch.allclose(generated_ids, EXPECTED_OUTPUT))

EXPECTED_TRANSCRIPT = [
{
"text": (
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
" Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive"
" season of the year, with Christmas and roast beef looming before us, similes drawn from eating"
" and its results occur most readily to the mind. He has grave doubts whether Sir Frederick "
"Leighton's work is really Greek after all,"
),
"offsets": [
{
"text": (
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
),
"timestamp": (0.0, 5.28),
},
{
"text": " Nor is Mr. Quilter's manner less interesting than his matter.",
"timestamp": (6.34, 10.1),
},
{
"text": (
" He tells us that at this festive season of the year, with Christmas and roast beef looming before us,"
),
"timestamp": (10.92, 17.6),
},
{
"text": (" similes drawn from eating and its results occur most readily to the mind."),
"timestamp": (18.44, 22.580000000000002),
},
{
"text": (
" He has grave doubts whether Sir Frederick Leighton's work is really Greek after all,"
),
"timestamp": (23.16, 28.68),
},
],
}
]

transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

@slow
def test_tiny_token_timestamp_generation(self):
set_seed(0)
@@ -1941,6 +2002,36 @@ def test_tiny_token_timestamp_generation(self):

self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))

@slow
def test_large_token_timestamp_generation(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
model.to(torch_device)

input_speech = self._load_datasamples(4)
input_features = processor(
input_speech, return_tensors="pt", sampling_rate=16_000, return_token_timestamps=True
)
input_features = input_features.to(torch_device)

generate_outputs = model.generate(
**input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
)

self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)

# fmt: off
EXPECTED_OUTPUT = torch.tensor([
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6200, 0.7400, 0.8600, 1.0000, 1.0400, 1.3000, 1.4400, 1.7800, 2.1800, 2.2800, 2.5000, 2.9200, 3.0000, 3.3800, 3.5000, 3.6000, 3.8400, 4.1000, 4.4000, 4.6800, 5.1400, 5.3600, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6000, 0.9200, 1.2200, 1.3400, 1.4200, 1.5400, 1.5800, 1.7400, 2.0600, 2.3800, 3.0400, 3.3800, 3.6400, 4.1200, 4.3600, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5400, 0.8200, 1.1600, 1.4600, 1.7400, 1.8800, 2.3400, 2.7400, 3.1400, 3.2200, 3.5400, 4.2800, 4.5600, 4.8200, 5.0600, 5.3200, 5.6600, 5.9600, 6.1400, 6.4000, 6.8400, 7.8800, 8.0200, 8.3600, 8.7000, 9.0200, 9.3200, 9.5000, 9.8400, 10.3000, 10.6600, 11.0800, 11.3600, 11.4600, 11.8000, 12.4600],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5600, 0.7600, 1.0600, 1.4000, 1.8800, 2.2600, 2.6200, 2.8000, 2.9600, 3.0000, 3.2000, 3.4400, 3.6800, 4.0000, 4.6000, 5.0000, 5.3200, 5.4800, 6.0600, 6.0600, 6.1000, 6.3200, 6.7400, 7.0000, 7.2200, 7.4000, 7.7600, 8.0600, 8.5600, 8.8600, 8.9400, 9.1000, 9.3400, 9.8800, 9.8800, 9.8800]
])
# fmt: on

self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))

@slow
def test_tiny_token_timestamp_batch_generation(self):
set_seed(0)