-
Notifications
You must be signed in to change notification settings - Fork 3.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DRAFT] Token level timestamps with DTW (#375) #1485
[DRAFT] Token level timestamps with DTW (#375) #1485
Conversation
7a85b62
to
de75062
Compare
This is awesome. I think the big question is where the alignment heads actually reside in GGML. |
Yes, that is the most critical point at the moment. I suspect that we would need to save this tensor: // Inside whisper_build_graph_decoder
2448 // K * Q
2449 struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q); // THIS ONE
2450
2451 //struct ggml_tensor * KQ_scaled =
2452 // ggml_scale(ctx0,
2453 // KQ,
2454 // ggml_new_f32(ctx0, 1.0f/sqrt(float(n_state)/n_head))
2455 // );
2456
2457 // no masking for cross-attention
2458 //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
2459
2460 struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ); // OR THIS ONE In OpenAI implementation, they do a second pass through the model passing the tokens generated on the first pass to be able to retrieve the correct weights. I think if we are able to that second pass just like in openAI impl, it should just be a matter of saving those KQs and selecting the ones useful for timestamping. Apparently lintoai whisper-timestamped seems to be able to do in one pass, only when there is no temperature fallback, greedy bestof = 1 and no beam search. If any of those conditions are not met, then it resorts to doing a second pass like on OpenAI impl. In what concerns of selecting heads that are useful for timestamping, from what i have seen, a accepted default would be the heads of the top most half of decoder layers (the ones closer to model output). A more optimal selection, which is provided on OpenAIs whisper individually by model, was apparently obtained by manual inspection . Note that alignment heads differ even for models with same dimensions (e.g. medium and medium.en have different alignment heads specified) We could still use these indexes for the usual pre-trained models but offer the option to use the n top-most decoder layers OR some custom index i guess. |
de75062
to
3c9969e
Compare
Looking forward to the progress 🎉 |
Haven't been able to work with this past week, but making some progress now! Trying to get a very poorly implemented end-to-end POC for the Validated most of the pre-processing operations done on QKs before passing then to DTW to get timestamps. So, taking the OpenAI implementation on timing.py: (with added comments by me) # Implemented on whisper.cpp, seems ok
tokens = torch.tensor(
[
*tokenizer.sot_sequence,
tokenizer.no_timestamps,
*text_tokens,
tokenizer.eot,
]
).to(model.device)
# I believe that i am making a mistake somewhere in this block
# install hooks on the cross attention layers to retrieve the attention weights
QKs = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
)
for i, block in enumerate(model.decoder.blocks)
]
with torch.no_grad():
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
token_probs = sampled_logits.softmax(dim=-1)
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
text_token_probs = text_token_probs.tolist()
for hook in hooks:
hook.remove()
# Implemented poorly in whisper.cpp (only using base.en alignment heads for now)
# heads * tokens * frames
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
weights = weights[:, :, : num_frames // 2]
weights = (weights * qk_scale).softmax(dim=-1)
# Implemented and validated on whisper.cpp
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
weights = (weights - mean) / std
weights = median_filter(weights, medfilt_width)
matrix = weights.mean(axis=0)
matrix = matrix[len(tokenizer.sot_sequence) : -1]
text_indices, time_indices = dtw(-matrix) Now, to get the whisper.cpp QKs, i temporarily added a // Inside whisper_build_graph_decoder
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ); // THIS TENSOR HERE!
if (wstate.decoder_save_cross_QKs) {
wstate.cross_QKs.push_back(KQ_soft_max);
}
// After inference, when computing timestamps
state->decoder_save_cross_QKs = true;
if (whisper_decode_with_state(ctx, state, tokens.data(), tokens.size(), 0, n_threads) != 0) {
WHISPER_LOG_INFO("DECODER FAILED\n");
}
state->decoder_save_cross_QKs = false; Doing this, i get a set of QKs to work with inside my timestamping function, and the amount of tensors retrieved and their dimensions are in line with what's retrieved in the openAI impl. Now, since the output i got so far for timestamps is complete garbage, and most of the operations outside retrieval of QKs seem correct, i imagine something is wrong when i call Any insight into how to correctly retrieve these QKs from decoder cross attention layers is very welcome! I'm currently on working on validating all operations other than QK retrieval to be absolutely sure this is where my mistake is. |
Currently stuck on retrieving the attention weights from the decoder. In all my attempts, i either get tensors with floats > 1 (which indicates that they are not the output of the softmax layer i'm trying to retrieve), null tensors, or tensors with null "data" pointer. What i have tried so far (inside
// Inside whisper_build_graph_decoder
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
wstate.cross_QKs.push_back(KQ_soft_max);
// Later on, e.g. on timestamping function
// Many values are > 1, indicating that they are not the output of softmax
float v = ggml_get_f32_nd(state.cross_QKs[0], 0, 0, 0, 0);
// Inside whisper_build_graph_decoder
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
struct ggml_tensor * KQ_copy = ggml_cpy(ctx0, KQ_soft_max, ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, KQ_soft_max->ne[0], KQ_soft_max->ne[1], KQ_soft_max->ne[2]));
wstate.cross_QKs.push_back(KQ_copy);
// Later on, e.g. on timestamping function
if (state.cross_QKs[0]->data == NULL)
// this is true
// Inside whisper_build_graph_decoder
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ);
struct ggml_tensor * KQ_copy = ggml_cpy(ctx0, KQ_soft_max, ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, KQ_soft_max->ne[0], KQ_soft_max->ne[1], KQ_soft_max->ne[2]));
char name[256];
snprintf(name, 256, "cross_QK_%d", il);
ggml_set_name(KQ_soft_max, name);
snprintf(name, 256, "cross_QK_copy_%d", il);
ggml_set_name(KQ_copy, name);
// Later on, right after ggml_graph_compute_helper inside whisper_decode_internal
ggml_graph_compute_helper(wstate.backend, gf, n_threads);
struct ggml_tensor * cross_QK = ggml_graph_get_tensor(gf, "cross_QK_0");
struct ggml_tensor * cross_copy = ggml_graph_get_tensor(gf, "cross_QK_copy_0");
float v = ggml_get_f32_nd(cross_QK, 0, 0, 0, 0) // v > 1, not softmax output
if (cross_copy == NULL)
// this is true @ggerganov would you be able to nudge me in right direction in how would i go about saving the values of the |
So apparently i was missing the fact that i had to call With that done, finally got a correct output for timestamps. Running over the
Which is identical to the timestamps retrieved on openAI impl before their heuristics to determine start/end of each token:
I'll be working on correctly plumbing everything together, since i poorly stitched everything together just to see end-to-end execution. |
Is this committed onto your fork? I would love to try it out. I have some audio that I consider to be a challenge, that the OpenAI Python one handles. I would love to compare it and test it. |
Happy to help with this if your able to commit your latest updates |
I'll try to commit the changes between today and tomorrow so you guys can give it spin! |
fbd390d
to
ab5cd86
Compare
I've committed these recent results into the fork. Currently, the timestamps are only being placed if you run with To run with DTW timestamps, you need to enable then in params and select a collection of alignment heads. I've imported those on OpenAI impl for each model, but have yet to implement setting custom alignment heads or using the N top most heads for alignment. So, e.g. for diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 9699802..c101434 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -942,6 +942,10 @@ int main(int argc, char ** argv) {
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
+ wparams.dtw_token_timestamps = true;
+ wparams.dtw_ah_preset = WHISPER_AHEADS_BASE_EN; // Match correctly with the model you are using.
+ wparams.single_segment = true;
+
wparams.print_realtime = false;
wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps; and timestamps should be placed on the auto token_data = whisper_full_get_token_data(ctx, i_segment, i_token);
float timestamp_seconds = (float)token_data.t_dtw/100; I've left several FIXMEs on I've only tested so far on |
So, i've updated the TODO list and pushed some fixes:
Any help in the current TODO list is very welcome, specially the last 3 items. I've made some tests with the |
936be88
to
f8d1b1f
Compare
Just an update, haven't been able to work the last few issues lately and probably won't be completing in December. Nevertheless, I've been doing some tests with whats implemented so far and timestamps seem to be working as expected. About the last couple of issues, what I'm having most trouble figuring it out is "Avoid memory allocations on whisper_exp_compute_token_level_timestamps_dtw". I'm not very confident on how to retrieve the total memory needed in each case, since some operations are done "manually" over ggml tensors instead of using a ggml graph. Also, the memory needed can vary greatly - like from <16MB to >2GB - depending on factors such as how many tokens the model yielded this run, audio_ctx size and the number of alignment heads (which is limited by numbers of attention heads in total). This can all be taken into account when allocating to try to allocate the minimum necessary for the worst case. Any help on these final issues would be very appreciated! |
I've been running some tests on this today. It's not perfect but it's definitely way way better. I have a test video that I've been using that has a lot of easy to find issues that let me compare the released version vs this branch (and against the python implementation). It starts with a 12 second silent section, and has a few silent sections peppered throughout. (Link below). It should be identifying the first word as starting at about 12.4 seconds in, but its identifying it as 7.6 seconds in. Thats a lot better than the original that was placing it at 0 seconds in. The second and third token seem to be correctly identified tho. Overall, this is incredible work! I'm going to submit a PR here shortly that should expose this a little easier in the main.cpp and go bindings. Is there a model auto-detect already? I'm not seeing it here but I've only just barely started looking at the actual code. Queen 'We will meet again' speech. |
Hey @bmurray, thanks for testing! I'll try to run this audio here as well. May i ask which model size and alignment head preset were you using? About that first token with wrong timestamp, i will try to test it on openAI impl as well to see their DTW output. It might be the case that they get the same output and fix that on the subsequent heuristics after running DTW. If you check their implementation, what we have implemented stops about here. After returning from that function, they still do some stuff that seems to be especially aimed to improving timestamps on bounds About model auto-detect, I'm not sure how we can implement that. Since alignment heads are different for models with same size (e.g. large-v1 and large-v2 have different alignment heads), we might need user input to guarantee we are using the correct alignment heads. Not sure if I'm mistaken though. Just a PS, I'll be out after this week and probably return to this only start/mid January, but I'll happily address any comments when i return. |
// dtw | ||
// supposedly can be optmized by computing diagonals in parallel ? | ||
// Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How long does this step take?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, i've just made some measurements and it is quite rather insignificant compared to the rest of the function. On a Apple M2 and metal enabled, less than 1% of the whole time-stamping process. (Moved longer analysis to MR comments)
f8d1b1f
to
e99b9e5
Compare
add2db7
to
10b0304
Compare
The DTW timestamps can now be generated with ./main -m models/ggml-small.bin -f samples/gb0.wav -dtw small -ojf {
"timestamps": {
"from": "00:02:02,960",
"to": "00:02:05,520"
},
"offsets": {
"from": 122960,
"to": 125520
},
"text": " by the will of the people.",
"tokens": [
{
"text": " by",
"timestamps": {
"from": "00:02:02,960",
"to": "00:02:03,180"
},
"offsets": {
"from": 122960,
"to": 123180
},
"id": 538,
"p": 0.999897,
"t_dtw": 12312
},
{
"text": " the",
"timestamps": {
"from": "00:02:03,180",
"to": "00:02:03,510"
},
"offsets": {
"from": 123180,
"to": 123510
},
"id": 264,
"p": 0.999729,
"t_dtw": 12328
},
{
"text": " will",
"timestamps": {
"from": "00:02:03,510",
"to": "00:02:03,590"
},
"offsets": {
"from": 123510,
"to": 123590
},
"id": 486,
"p": 0.997792,
"t_dtw": 12354
},
{
"text": " of",
"timestamps": {
"from": "00:02:04,140",
"to": "00:02:04,170"
},
"offsets": {
"from": 124140,
"to": 124170
},
"id": 295,
"p": 0.999649,
"t_dtw": 12430
},
{
"text": " the",
"timestamps": {
"from": "00:02:04,170",
"to": "00:02:04,500"
},
"offsets": {
"from": 124170,
"to": 124500
},
"id": 264,
"p": 0.999611,
"t_dtw": 12440
},
{
"text": " people",
"timestamps": {
"from": "00:02:04,500",
"to": "00:02:05,090"
},
"offsets": {
"from": 124500,
"to": 125090
},
"id": 561,
"p": 0.999641,
"t_dtw": 12482
},
{
"text": ".",
"timestamps": {
"from": "00:02:05,200",
"to": "00:02:05,440"
},
"offsets": {
"from": 125200,
"to": 125440
},
"id": 13,
"p": 0.998121,
"t_dtw": 12512
},
{
"text": "[_TT_416]",
"timestamps": {
"from": "00:02:05,520",
"to": "00:02:05,520"
},
"offsets": {
"from": 125520,
"to": 125520
},
"id": 50780,
"p": 0.280601,
"t_dtw": -1
}
]
}, |
* whisper.cpp: impl dtw algo * WIP: producing and placing DTW timestamps on tokens * Fix compile and assertion errors. Attempt to DTW timestamp with single_segment=false. * Fix mistake causing incorrect alignment of dtw timestamps * implement N_TOP_MOST and CUSTOM alignment heads setting * whisper: fix typo on alignment heads enum * Fix issues related to changes in whisper.cpp * Fixed excessive memory use when using DTW timestamps. Other minor fixes to DTW timestamping function * decoder: save cross QKs only if requested * Calling median filter with ggml_map_custom1 * Reimpl aheads n_top_most and custom. Sanity checks on chosen aheads * Copying cross QKs from decoder backend correctly * dtw: cleanup * Fix incorrect n_frames passed to dtw when near end of audio * Fix aheads_masks_init for backend != CPU * whisper : minor style * main : add dtw (wip) * whisper: fix invalid memory access in aheads_masks_init * main : add dtw (cont) * whisper : minor --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
I am looking to derive word level timestamps from DTW timestamps. Some tokens need to be joined to form a complete word, so as a start I figured to just concatenate tokens until a space or punctuation mark is encountered, and from just accumulate the timestamps from the tokens to determine the start and end times for each word. I am pretty sure there must be way more to consider. @denersc it would be great if you can comment on this. - and hey super thanks for your work on the DTW timestamps. |
not sure if anyone has noticed this too but the DTW timestamps seem to be completely inaccurate in some segments of the transcript. while other segments, it's very precise. |
I was experimenting yesterday and saw forward shifts in token timestamps prior to silence periods. (Say pauses in speech for 1-2 sec). In my tests I did have a good bit of background noise that might have been a factor. I will try to redo some experiments in the next few days. |
Hey @hlevring and @eschmidbauer, thanks for trying it out! So, I'll try to address what you guys said, but unfortunately i don't think i can provide perfect answers. First, when thinking about DTW timestamps, i crudely rationalize about them as "A estimate of the moment the model decided to output a certain token". So, in common speech flow, like between equal paced words in a sentence, it is very likely that the DTW timestamp of the last token in a word will be very close to the actual time of the end of the word. Nevertheless, it is not unusual that the model will output some token long after it actually occurred in audio. In that case, DTW timestamp will likely be incorrect. The most common example of this is the period (.) token. It can outputted by the model after some time of silence preceded by a sequence of words. So the DTW timestamp for the period will be long after sentence end, and should probably be ignored. Although period and punctuation in general are the most common occurrence of this, I don't doubt this kind of thing can happen with words, e.g. model outputting 3 words almost simultaneously because only when it understood the third word could it actually infer all three. In that case, the first 2 words will have very imprecise timestamps. Although i think this may be possible, it does not seem to be very likely, at least in the sense that i haven't observed it directly. All of this to say, DTW timestamps are a imperfect source of information, and should be used with some caution and combined with other data to provide good word timestamp estimates. OpenAI tries to address some of theses issues Of course, it may also be the case that my implementation is incorrect on some point. Maybe on the step of selecting and saving alignment heads or when performing some matrix operations. I think a good starting point to check that would be to compare the DTW timestamps given by openAI impl with the ones in my implementation for a variety of audios and see if there are any large discrepancies. Some very small variance is bound to happen probably because of different matrix operation implementations. Finally, I'm more on the developer side than on the ML research side, my math understanding is reasonably shallow. So take all i said with a grain of salt 😬 |
This is a great idea! I'll set some time aside to compare the two & post my findings |
Also, forgot to say, make sure whisper.cpp version you guys are using is after MR #2012. That bug likely caused incorrect alignment head selection, beyond the observed memory error. Cool @eschmidbauer. On whisper.cpp, you can uncomment these lines if you want to print DTW timestamps. You might need to change the code on the OpenAI package to get the actual raw DTW timestamps, since they don't provide them to the final user. They do a lot of additional processing before giving it back, so those will be different for sure. You'll probably have to add some code to retrieve these and do some sort of loop equivalent to what i did here to get timestamps for each token. These will be the raw timestamps which are comparable to the ones i made available on whisper.cpp. |
* whisper.cpp: impl dtw algo * WIP: producing and placing DTW timestamps on tokens * Fix compile and assertion errors. Attempt to DTW timestamp with single_segment=false. * Fix mistake causing incorrect alignment of dtw timestamps * implement N_TOP_MOST and CUSTOM alignment heads setting * whisper: fix typo on alignment heads enum * Fix issues related to changes in whisper.cpp * Fixed excessive memory use when using DTW timestamps. Other minor fixes to DTW timestamping function * decoder: save cross QKs only if requested * Calling median filter with ggml_map_custom1 * Reimpl aheads n_top_most and custom. Sanity checks on chosen aheads * Copying cross QKs from decoder backend correctly * dtw: cleanup * Fix incorrect n_frames passed to dtw when near end of audio * Fix aheads_masks_init for backend != CPU * whisper : minor style * main : add dtw (wip) * whisper: fix invalid memory access in aheads_masks_init * main : add dtw (cont) * whisper : minor --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* whisper.cpp: impl dtw algo * WIP: producing and placing DTW timestamps on tokens * Fix compile and assertion errors. Attempt to DTW timestamp with single_segment=false. * Fix mistake causing incorrect alignment of dtw timestamps * implement N_TOP_MOST and CUSTOM alignment heads setting * whisper: fix typo on alignment heads enum * Fix issues related to changes in whisper.cpp * Fixed excessive memory use when using DTW timestamps. Other minor fixes to DTW timestamping function * decoder: save cross QKs only if requested * Calling median filter with ggml_map_custom1 * Reimpl aheads n_top_most and custom. Sanity checks on chosen aheads * Copying cross QKs from decoder backend correctly * dtw: cleanup * Fix incorrect n_frames passed to dtw when near end of audio * Fix aheads_masks_init for backend != CPU * whisper : minor style * main : add dtw (wip) * whisper: fix invalid memory access in aheads_masks_init * main : add dtw (cont) * whisper : minor --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* whisper.cpp: impl dtw algo * WIP: producing and placing DTW timestamps on tokens * Fix compile and assertion errors. Attempt to DTW timestamp with single_segment=false. * Fix mistake causing incorrect alignment of dtw timestamps * implement N_TOP_MOST and CUSTOM alignment heads setting * whisper: fix typo on alignment heads enum * Fix issues related to changes in whisper.cpp * Fixed excessive memory use when using DTW timestamps. Other minor fixes to DTW timestamping function * decoder: save cross QKs only if requested * Calling median filter with ggml_map_custom1 * Reimpl aheads n_top_most and custom. Sanity checks on chosen aheads * Copying cross QKs from decoder backend correctly * dtw: cleanup * Fix incorrect n_frames passed to dtw when near end of audio * Fix aheads_masks_init for backend != CPU * whisper : minor style * main : add dtw (wip) * whisper: fix invalid memory access in aheads_masks_init * main : add dtw (cont) * whisper : minor --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Tries to solve #375
Attempt to implement DTW-based token level timestamps, as seen in OpenAI Whisper.
This first commit implements the DTW algorithm on whisper.cpp and provides tests that compare the output with the output of OpenAI's implementation. Tests are done calling whisper.cpp from Python and comparing DTW output with OpenAI's
dtw_cpu
function.An outline of remaining work is commented on
whisper_exp_compute_token_level_timestamps_dtw
in whisper.cpp. Help/insights are very appreciated, specially concerning how to cache/retrieve the output of MHA layers that are used as input for DTW.In OpenAI's implementation, token-level timestamps are used with further heuristics to determine a supposed start/end time for words. In this PR, my intention is to implement token-level only as a first step that can be used to implement word timestamps in the future.
TODO
the output of the MHA layersQKs from cross-attention layers from alignment heads (perhaps in whisper_state?) and retrieve them inwhisper_exp_compute_token_level_timestamps_dtw
whisper_exp_compute_token_level_timestamps_dtw
whisper_exp_compute_token_level_timestamps_dtw
intowhisper_full
and use results to place timestamps on each inferred tokenN_TOP_MOST
alignment headsCUSTOM
alignment heads, decide comfortable API for setting custom alignment headswhisper_build_graph_decoder
to only save QK copies if requested, so there is no additional overhead when running decoder for other reasons than timestamps.whisper_exp_compute_token_level_timestamps_dtw
(probably allocate a buffer for the used tensors on init)whisper_exp_compute_token_level_timestamps_dtw
that are currently done with manual for loops can benefit from ggml functions.