add return_token_timestamps to WhisperProcessor (#30812)
* compute num_frames in WhisperFeatureExtractor

* add return_num_frames in WhisperFeatureProcessor + adapt pipeline

* return_timestamps renaming + pipeline fix

* fix

* fix

* fix

* add tests

* Update src/transformers/models/whisper/feature_extraction_whisper.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* apply review changes

* fix

* Update src/transformers/models/whisper/feature_extraction_whisper.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update tests/models/whisper/test_modeling_whisper.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* apply review

* fix

* review changes

* Update src/transformers/models/whisper/feature_extraction_whisper.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* make style quality

* EXPECTED_OUTPUT in single line

* small numpy->torch fix

* fix

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
3 people authored May 20, 2024
1 parent 66b0d9e commit 1c2bb3a
Showing 4 changed files with 127 additions and 17 deletions.
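In short: the feature extractor can now report how many mel frames each raw waveform spans (num_frames), and Whisper's generate uses that count to anchor word-level timestamps for short-form audio. A minimal end-to-end sketch of the new flow, mirroring the tests further down; the tiny checkpoint and the silent dummy audio are illustrative, not part of this diff:

import numpy as np
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# 1.5 s of silence stands in for a real 16 kHz recording
audio = np.zeros(24_000, dtype=np.float32)

# New in this commit: return_token_timestamps makes the processor also emit num_frames
inputs = processor(audio, sampling_rate=16_000, return_tensors="pt", return_token_timestamps=True)

# num_frames rides along in `inputs` and is consumed by the timestamp extraction
outputs = model.generate(**inputs, return_timestamps=True, return_token_timestamps=True)
print(outputs.token_timestamps)  # one timestamp per generated token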
8 changes: 8 additions & 0 deletions src/transformers/models/whisper/feature_extraction_whisper.py
@@ -188,6 +188,7 @@ def __call__(
sampling_rate: Optional[int] = None,
do_normalize: Optional[bool] = None,
device: Optional[str] = "cpu",
return_token_timestamps: Optional[bool] = None,
**kwargs,
) -> BatchFeature:
"""
@@ -237,6 +238,9 @@ def __call__(
device (`str`, *optional*, defaults to `'cpu'`):
Specifies the device for computation of the log-mel spectrogram of audio signals in the
`_torch_extract_fbank_features` method. (e.g., "cpu", "cuda")
return_token_timestamps (`bool`, *optional*, defaults to `None`):
Whether or not to return the number of frames of the input raw_speech.
This num_frames value can be used by the model to compute word-level timestamps.
"""

if sampling_rate is not None:
@@ -302,13 +306,17 @@ def __call__(

if isinstance(input_features[0], List):
padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]

else:
padded_inputs["input_features"] = input_features

if return_attention_mask:
# rescale from sample (48000) to feature (3000)
padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]

if return_token_timestamps is not None:
padded_inputs["num_frames"] = [len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech]

if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

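The num_frames bookkeeping added above is plain integer arithmetic: sample count divided by the extractor's hop length. A standalone sketch of that computation, with the standard Whisper constants (16 kHz sampling rate, hop length 160) assumed rather than read from a config:

sampling_rate = 16_000
hop_length = 160                        # one mel frame every 10 ms

num_samples = int(5.2 * sampling_rate)  # a 5.2 s clip -> 83_200 samples
num_frames = num_samples // hop_length  # what lands in padded_inputs["num_frames"]
print(num_frames)                       # 520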
15 changes: 10 additions & 5 deletions src/transformers/models/whisper/generation_whisper.py
@@ -209,11 +209,15 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec
# 2. num_frames is different, compute the DTW matrix for each sample sequentially

# we're using np.unique because num_frames can be int/list/tuple
if len(np.unique(num_frames)) == 1:
# if num_frames is the same, no need to recompute matrix, std and mean for each element of the batch
num_frames = num_frames if isinstance(num_frames, int) else num_frames[0]

if isinstance(num_frames, int):
weights = weights[..., : num_frames // 2]

elif isinstance(num_frames, (list, tuple, np.ndarray)) and len(np.unique(num_frames)) == 1:
weights = weights[..., : num_frames[0] // 2]

elif isinstance(num_frames, (torch.Tensor)) and len(torch.unique(num_frames)) == 1:
weights = weights[..., : num_frames[0] // 2]

else:
# num_frames is of shape (batch_size,) whereas batch_size is truly batch_size*num_return_sequences
repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames)
@@ -231,7 +235,7 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec

# Perform dynamic time warping on each element of the batch.
for batch_idx in range(batch_size):
if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray)):
if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray, torch.Tensor)):
matrix = weights[batch_idx, ..., : num_frames[batch_idx] // 2]

# Normalize and smoothen the weights.
@@ -475,6 +479,7 @@ def generate(
"The input name `inputs` is deprecated. Please make sure to use `input_features` instead.",
FutureWarning,
)

# 1. prepare generation config
generation_config, kwargs = self._prepare_generation_config(generation_config, **kwargs)

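The type dispatch above lets the common case, where every sample in the batch spans the same number of frames, truncate the cross-attention weights with a single slice instead of re-slicing inside the per-sample DTW loop. A toy sketch of that uniform-length fast path, with made-up shapes:

import numpy as np
import torch

# (batch, heads, decoded tokens, encoder positions) -- made-up sizes
weights = torch.rand(2, 6, 10, 1500)

num_frames = [1500, 1500]  # mel frames per sample; identical across the batch

# Uniform lengths: one slice serves the whole batch
if isinstance(num_frames, (list, tuple, np.ndarray)) and len(np.unique(num_frames)) == 1:
    weights = weights[..., : num_frames[0] // 2]

print(weights.shape)  # torch.Size([2, 6, 10, 750])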
28 changes: 16 additions & 12 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -443,11 +443,18 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
return_tensors="pt",
)
else:
processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
if stride is None:
extra["segment_size"] = len(inputs)
if self.type == "seq2seq_whisper" and stride is None:
processed = self.feature_extractor(
inputs,
sampling_rate=self.feature_extractor.sampling_rate,
return_tensors="pt",
return_token_timestamps=True,
)
extra["num_frames"] = processed.pop("num_frames")
else:
processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)

if self.torch_dtype is not None:
processed = processed.to(dtype=self.torch_dtype)
Expand All @@ -461,11 +468,11 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
attention_mask = model_inputs.pop("attention_mask", None)
stride = model_inputs.pop("stride", None)
segment_size = model_inputs.pop("segment_size", None)
num_frames = model_inputs.pop("num_frames", None)
is_last = model_inputs.pop("is_last")

if stride is not None and segment_size is not None:
raise ValueError("segment_size must be used only when stride is None")
if stride is not None and num_frames is not None:
raise ValueError("num_frames must be used only when stride is None")

if self.type in {"seq2seq", "seq2seq_whisper"}:
encoder = self.model.get_encoder()
@@ -495,10 +502,7 @@ def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride]

else:
if isinstance(segment_size, int):
generate_kwargs["num_frames"] = segment_size // self.feature_extractor.hop_length
else:
generate_kwargs["num_frames"] = segment_size[0] // self.feature_extractor.hop_length
generate_kwargs["num_frames"] = num_frames

if self.type == "seq2seq_whisper" and inputs.shape[-1] > self.feature_extractor.nb_max_frames:
generate_kwargs["input_features"] = inputs
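On the pipeline side, the net effect is that short-form calls (no chunk_length_s, hence no stride) now hand an accurate num_frames to generate, so word-level timestamps reflect the true audio length rather than the padded 30 s window. A usage sketch, assuming a 16 kHz mono waveform:

import numpy as np
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

# 2 s of silence stands in for a real 16 kHz recording
audio = np.zeros(32_000, dtype=np.float32)

# No chunking: preprocess takes the seq2seq_whisper branch above and forwards num_frames
result = asr(audio, return_timestamps="word")
print(result["chunks"])  # each word chunk carries a (start, end) timestamp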
93 changes: 93 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py
@@ -1950,6 +1950,69 @@ def test_tiny_timestamp_generation(self):
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

@slow
def test_large_timestamp_generation(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
model.to(torch_device)

input_speech = np.concatenate(self._load_datasamples(4))
input_features = processor(
input_speech, return_tensors="pt", sampling_rate=16_000, return_token_timestamps=True
).input_features
input_features = input_features.to(torch_device)

generated_ids = model.generate(input_features, max_length=448, return_timestamps=True).to("cpu")

# fmt: off
EXPECTED_OUTPUT = torch.tensor([50258, 50259, 50360, 50365, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50629, 50682, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50870, 50911, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 949, 505, 11, 51245, 51287, 1034, 4680, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51494, 51523, 634, 575, 12525, 22618, 1968, 6144, 35617, 1456, 397, 266, 311, 589, 307, 534, 10281, 934, 439, 11, 51799, 51815, 50257])
# fmt: on
self.assertTrue(torch.allclose(generated_ids, EXPECTED_OUTPUT))

EXPECTED_TRANSCRIPT = [
{
"text": (
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
" Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive"
" season of the year, with Christmas and roast beef looming before us, similes drawn from eating"
" and its results occur most readily to the mind. He has grave doubts whether Sir Frederick "
"Leighton's work is really Greek after all,"
),
"offsets": [
{
"text": (
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
),
"timestamp": (0.0, 5.28),
},
{
"text": " Nor is Mr. Quilter's manner less interesting than his matter.",
"timestamp": (6.34, 10.1),
},
{
"text": (
" He tells us that at this festive season of the year, with Christmas and roast beef looming before us,"
),
"timestamp": (10.92, 17.6),
},
{
"text": (" similes drawn from eating and its results occur most readily to the mind."),
"timestamp": (18.44, 22.580000000000002),
},
{
"text": (
" He has grave doubts whether Sir Frederick Leighton's work is really Greek after all,"
),
"timestamp": (23.16, 28.68),
},
],
}
]

transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
self.assertEqual(transcript, EXPECTED_TRANSCRIPT)

@slow
def test_tiny_token_timestamp_generation(self):
set_seed(0)
@@ -1979,6 +2042,36 @@ def test_tiny_token_timestamp_generation(self):

self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))

@slow
def test_large_token_timestamp_generation(self):
set_seed(0)
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3")
model.to(torch_device)

input_speech = self._load_datasamples(4)
input_features = processor(
input_speech, return_tensors="pt", sampling_rate=16_000, return_token_timestamps=True
)
input_features = input_features.to(torch_device)

generate_outputs = model.generate(
**input_features, max_length=448, return_timestamps=True, return_token_timestamps=True
)

self.assertEqual(generate_outputs.sequences.shape, generate_outputs.token_timestamps.shape)

# fmt: off
EXPECTED_OUTPUT = torch.tensor([
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6200, 0.7400, 0.8600, 1.0000, 1.0400, 1.3000, 1.4400, 1.7800, 2.1800, 2.2800, 2.5000, 2.9200, 3.0000, 3.3800, 3.5000, 3.6000, 3.8400, 4.1000, 4.4000, 4.6800, 5.1400, 5.3600, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200, 5.8200],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6000, 0.9200, 1.2200, 1.3400, 1.4200, 1.5400, 1.5800, 1.7400, 2.0600, 2.3800, 3.0400, 3.3800, 3.6400, 4.1200, 4.3600, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800, 4.7800],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5400, 0.8200, 1.1600, 1.4600, 1.7400, 1.8800, 2.3400, 2.7400, 3.1400, 3.2200, 3.5400, 4.2800, 4.5600, 4.8200, 5.0600, 5.3200, 5.6600, 5.9600, 6.1400, 6.4000, 6.8400, 7.8800, 8.0200, 8.3600, 8.7000, 9.0200, 9.3200, 9.5000, 9.8400, 10.3000, 10.6600, 11.0800, 11.3600, 11.4600, 11.8000, 12.4600],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5600, 0.7600, 1.0600, 1.4000, 1.8800, 2.2600, 2.6200, 2.8000, 2.9600, 3.0000, 3.2000, 3.4400, 3.6800, 4.0000, 4.6000, 5.0000, 5.3200, 5.4800, 6.0600, 6.0600, 6.1000, 6.3200, 6.7400, 7.0000, 7.2200, 7.4000, 7.7600, 8.0600, 8.5600, 8.8600, 8.9400, 9.1000, 9.3400, 9.8800, 9.8800, 9.8800]
])
# fmt: on

self.assertTrue(torch.allclose(generate_outputs.token_timestamps.to("cpu"), EXPECTED_OUTPUT))

@slow
def test_tiny_token_timestamp_batch_generation(self):
set_seed(0)
