[whisper] alternative fix for long-form timestamps #32131

Merged
13 changes: 11 additions & 2 deletions src/transformers/models/whisper/tokenization_whisper.py
@@ -587,11 +587,20 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
             consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)
 
         last_slice = np.where(timestamp_tokens)[0][0]
+        cur_max_timestamp = 0
Contributor Author:

With no change to the overall processor/tokenizer design, we can fix the original timestamp issue by keeping track of the last predicted timestamp.

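For intuition, a minimal standalone sketch of the rewind detection this diff adds (the segment positions below are invented for illustration; `time_precision` is the tokenizer's 0.02 s per timestamp position):

```python
# Whisper emits timestamp positions relative to the current 30 s window, and
# they reset to 0 when a new window begins. A start position smaller than the
# running maximum therefore marks a window boundary; the elapsed time is
# carried forward in prev_segments_len, mirroring the loop in the diff.
time_precision = 0.02  # seconds per timestamp position

# (start_position, end_position) per segment; the final pair has wrapped around
segments = [(0, 328), (328, 562), (562, 1088), (1088, 1500), (12, 296)]

cur_max_timestamp = 0
prev_segments_len = 0
for start, end in segments:
    if start < cur_max_timestamp:
        # next segment has started in a new window
        prev_segments_len += cur_max_timestamp
    cur_max_timestamp = end
    print((start + prev_segments_len) * time_precision,
          (end + prev_segments_len) * time_precision)
# The final segment prints 30.24 35.92 instead of wrapping back to 0.24 5.92.
```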
+        prev_segments_len = 0
         for current_slice in consecutive:
             sliced_tokens = token_ids[last_slice:current_slice]
             if len(sliced_tokens) > 1:
                 start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
                 end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
+
+                if start_timestamp_position < cur_max_timestamp:
+                    # next segment has started
+                    prev_segments_len += cur_max_timestamp
+
+                cur_max_timestamp = end_timestamp_position
+
                 # strip timestamp tokens from the text output
                 sliced_tokens = self._preprocess_token_ids(sliced_tokens)
                 text = self._decode(sliced_tokens)
@@ -600,8 +609,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
                     {
                         "text": text,
                         "timestamp": (
-                            start_timestamp_position * time_precision,
-                            end_timestamp_position * time_precision,
+                            (start_timestamp_position + prev_segments_len) * time_precision,
+                            (end_timestamp_position + prev_segments_len) * time_precision,
                         ),
                     }
                 )
13 changes: 11 additions & 2 deletions src/transformers/models/whisper/tokenization_whisper_fast.py
@@ -229,11 +229,20 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
             consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)
 
         last_slice = np.where(timestamp_tokens)[0][0]
+        cur_max_timestamp = 0
+        prev_segments_len = 0
         for current_slice in consecutive:
             sliced_tokens = token_ids[last_slice:current_slice]
             if len(sliced_tokens) > 1:
                 start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
                 end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
+
+                if start_timestamp_position < cur_max_timestamp:
+                    # next segment has started
+                    prev_segments_len += cur_max_timestamp
+
+                cur_max_timestamp = end_timestamp_position
+
                 # strip timestamp tokens from the text output
                 sliced_tokens = self._preprocess_token_ids(sliced_tokens)
                 text = self._decode(sliced_tokens)
@@ -242,8 +251,8 @@ def _compute_offsets(self, token_ids, time_precision=0.02):
                     {
                         "text": text,
                         "timestamp": (
-                            start_timestamp_position * time_precision,
-                            end_timestamp_position * time_precision,
+                            (start_timestamp_position + prev_segments_len) * time_precision,
+                            (end_timestamp_position + prev_segments_len) * time_precision,
                         ),
                     }
                 )
59 changes: 59 additions & 0 deletions tests/models/whisper/test_modeling_whisper.py
@@ -2243,6 +2243,65 @@ def test_tiny_timestamp_generation(self):
         transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
         self.assertEqual(transcript, EXPECTED_TRANSCRIPT)
 
+    @slow
+    def test_tiny_longform_timestamps_generation(self):
+        set_seed(0)
+        processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
+        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
+        model.to(torch_device)
+
+        dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
+        sample = dataset[0]["audio"]
+
+        input_features = processor(
+            sample["array"], return_tensors="pt", truncation=False, sampling_rate=sample["sampling_rate"]
+        )
+        input_features = input_features.to(torch_device)
+
+        generated_ids = model.generate(**input_features, return_timestamps=True, return_segments=True)
+
+        EXPECTED_TRANSCRIPT = [
+            {
+                "text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
+                "timestamp": (0.0, 6.5600000000000005),
+            },
+            {
+                "text": " Nor is Mr. Quilter's manner less interesting than his matter.",
+                "timestamp": (6.5600000000000005, 11.24),
+            },
+            {
+                "text": " He tells us that at this festive season of the year, with Christmas and roast beef looming",
+                "timestamp": (11.24, 16.88),
+            },
+            {
+                "text": " before us, similarly drawn from eating and its results occur most readily to the mind.",
+                "timestamp": (16.88, 23.76),
+            },
+            {
+                "text": " He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and",
+                "timestamp": (23.76, 29.44),
+            },
+            {"text": " can discover in it but little of rocky ithaka.", "timestamp": (29.44, 33.72)},
+            {
+                "text": " Lennils, pictures, are a sort of upguards and atom paintings, and Mason's exquisite itals",
+                "timestamp": (33.72, 40.32),
+            },
+            {"text": " are as national as a jingo poem.", "timestamp": (40.32, 44.72)},
+            {
+                "text": " Mr. Birkut Foster's landscapes smile at one much in the same way that Mr. Carker used",
+                "timestamp": (44.72, 50.4),
+            },
+            {"text": " to flash his teeth.", "timestamp": (50.4, 52.96)},
+            {
+                "text": " And Mr. John Collier gives his sitter a cheerful slap on the back before he says, like",
+                "timestamp": (52.96, 58.68),
+            },
+            {"text": " a shampoo and a Turkish bath next man.", "timestamp": (58.68, 61.96)},
+        ]
+
+        transcript = processor.batch_decode(generated_ids["sequences"], skip_special_tokens=True, output_offsets=True)
+        self.assertEqual(transcript[0]["offsets"], EXPECTED_TRANSCRIPT)
+
     @slow
     def test_large_timestamp_generation(self):
         set_seed(0)
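For context, an end-to-end usage sketch assembled from the test above, showing how the corrected offsets surface through `batch_decode` (same model and dataset as the test; exact transcriptions and timings may vary across versions):

```python
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]

# truncation=False keeps the full long-form audio rather than just the first 30 s
inputs = processor(
    sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt", truncation=False
)

out = model.generate(**inputs, return_timestamps=True, return_segments=True)
decoded = processor.batch_decode(out["sequences"], skip_special_tokens=True, output_offsets=True)

# With this fix, offsets keep increasing past the 30 s window boundary
# instead of wrapping back to 0:
for offset in decoded[0]["offsets"]:
    print(offset["timestamp"], offset["text"])
```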