implement N_TOP_MOST and CUSTOM alignment heads setting
denersc committed Dec 11, 2023
1 parent 619f0a8 commit 4f76929
Showing 2 changed files with 54 additions and 26 deletions.
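
A minimal usage sketch for the new N_TOP_MOST preset, using only parameter fields that appear in this diff (dtw_token_timestamps, dtw_ah_preset, dtw_n_top_most.n); the concrete value 2 is illustrative, not from the commit:

    // Sketch: request DTW token timestamps using all alignment heads from the
    // 2 top-most text layers. The value of dtw_n_top_most.n is illustrative;
    // it must not exceed the model's text-layer count (asserted in whisper.cpp).
    struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    params.dtw_token_timestamps = true;
    params.dtw_ah_preset        = WHISPER_AHEADS_N_TOP_MOST;
    params.dtw_n_top_most.n     = 2;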
78 changes: 53 additions & 25 deletions whisper.cpp
@@ -4421,7 +4421,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str

 /*.dtw_token_timestamps =*/ false,
 /*.dtw_ah_preset =*/ WHISPER_AHEADS_NONE,
-/*.dtw_n_stop_most =*/ {
+/*.dtw_n_top_most =*/ {
 /*.n =*/ -1,
 },
 /*.dtw_custom =*/ {
@@ -5890,7 +5890,6 @@ int whisper_full_with_state(

 int n_new = 1;

-
 if (params.token_timestamps) {
 whisper_exp_compute_token_level_timestamps(
 *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
@@ -6723,7 +6722,7 @@ static void whisper_exp_compute_token_level_timestamps(
 // dtw + backtrace to return found path
 // based on
 // https://github.com/openai/whisper/blob/main/whisper/timing.py#L83
-static ggml_tensor * dtw_and_backtrace(ggml_context *ctx, ggml_tensor *x) {
+static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
 WHISPER_ASSERT(x->n_dims == 2);

 int64_t N = x->ne[0];
@@ -6809,7 +6808,7 @@ static ggml_tensor * dtw_and_backtrace(ggml_context *ctx, ggml_tensor *x) {
 return r;
 }

-static ggml_tensor * median_filter(ggml_context *ctx, ggml_tensor *x, int filter_width) {
+static ggml_tensor * median_filter(ggml_context * ctx, ggml_tensor * x, int filter_width) {
 WHISPER_ASSERT(filter_width < x->ne[2]);
 WHISPER_ASSERT(filter_width % 2);
 WHISPER_ASSERT(x->n_dims == 3);
@@ -6843,6 +6842,54 @@ static ggml_tensor * median_filter(ggml_context *ctx, ggml_tensor *x, int filter
 return r;
 }

+static ggml_tensor * get_alignment_heads_QKs(
+ggml_context * ctx,
+struct whisper_state * state,
+struct whisper_full_params params,
+int n_audio_tokens)
+{
+const auto n_text_layers = (int) state->cross_QKs.size();
+const auto heads_per_layer = state->cross_QKs[0]->ne[2];
+const auto n_tokens = state->cross_QKs[0]->ne[1];
+
+if (params.dtw_ah_preset == WHISPER_AHEADS_N_TOP_MOST) {
+WHISPER_ASSERT(params.dtw_n_top_most.n <= n_text_layers);
+const auto n_heads = heads_per_layer * params.dtw_n_top_most.n;
+
+// FIXME: manually stacking + clipping + permuting might not be the most efficient way? (e.g. use ggml funcs)
+ggml_tensor * w = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_tokens, n_audio_tokens, n_heads);
+for (int k = 0; k < n_heads; ++k) {
+for (int i = 0; i < n_audio_tokens; ++i) {
+for (int j = 0; j < state->cross_QKs[0]->ne[1]; ++j) {
+auto text_layer = n_text_layers - (k / heads_per_layer) - 1;
+auto head = k % heads_per_layer;
+const float v = ggml_get_f32_nd(state->cross_QKs[text_layer], i, j, head, 0);
+ggml_set_f32_nd(w, j, i, k, 0, v);
+}
+}
+}
+return w;
+
+} else {
+const auto alignment_heads = params.dtw_ah_preset == WHISPER_AHEADS_CUSTOM ? params.dtw_custom.aheads : g_aheads.at(params.dtw_ah_preset);
+ggml_tensor * w = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, state->cross_QKs[0]->ne[1], n_audio_tokens, alignment_heads.n_heads);
+
+// FIXME: manually stacking + clipping + permuting might not be the most efficient way? (e.g. use ggml funcs)
+for (size_t k = 0; k < alignment_heads.n_heads; ++k) {
+for (int i = 0; i < n_audio_tokens; ++i) {
+for (int j = 0; j < state->cross_QKs[0]->ne[1]; ++j) {
+auto text_layer = alignment_heads.heads[k].n_text_layer;
+auto head = alignment_heads.heads[k].n_head;
+const float v = ggml_get_f32_nd(state->cross_QKs[text_layer], i, j, head, 0);
+ggml_set_f32_nd(w, j, i, k, 0, v);
+}
+}
+}
+return w;
+}
+}
+
+
 static void whisper_exp_compute_token_level_timestamps_dtw(
 struct whisper_context * ctx,
 struct whisper_state * state,
@@ -6858,17 +6905,11 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
 WHISPER_ASSERT(n_frames <= ctx->model.hparams.n_audio_ctx * 2);
 WHISPER_ASSERT(params.dtw_ah_preset != WHISPER_AHEADS_NONE);

-// unimplemented
-WHISPER_ASSERT(params.dtw_ah_preset != WHISPER_AHEADS_N_TOP_MOST);
-WHISPER_ASSERT(params.dtw_ah_preset != WHISPER_AHEADS_CUSTOM);
-
-const auto alignment_heads = g_aheads.at(params.dtw_ah_preset);
-
 // FIXME: Allocating mem everytime we call this func
 // Our ggml buffer should be pre-allocated somewhere during init and reused
 // when we call this function
 struct ggml_init_params gparams = {
-/*.mem_size =*/ 16*1024*1024,
+/*.mem_size =*/ 32*1024*1024,
 /*.mem_buffer =*/ NULL,
 /*.no_alloc =*/ false,
 };
@@ -6902,25 +6943,12 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
 WHISPER_ASSERT(0);
 }

-// FIXME: manually stacking + clipping + permuting might not be the most efficient way? (e.g. use ggml funcs)
 // Stack alignment heads + clip unused audio tokens
 // We permute dimensions so we can compute normalization on next step
 // IN: N_TEXT_LAYERS tensors with audio_ctx*N_TOKENS*N_HEADS dims
 // OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
 const auto n_audio_tokens = n_frames/2;
-//fprintf(stderr, "n_audio_tokens is %d\n", n_audio_tokens);
-ggml_tensor * w = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, state->cross_QKs[0]->ne[1], n_audio_tokens, alignment_heads.n_heads);
-for (size_t k = 0; k < alignment_heads.n_heads; k++) {
-for (int i = 0; i < n_audio_tokens; ++i) {
-for (int j = 0; j < state->cross_QKs[0]->ne[1]; ++j) {
-auto text_layer = alignment_heads.heads[k].n_text_layer;
-auto head = alignment_heads.heads[k].n_head;
-const float v = ggml_get_f32_nd(state->cross_QKs[text_layer], i, j, head, 0);
-ggml_set_f32_nd(w, j, i, k, 0, v);
-}
-}
-}
-//fprintf(stderr, "weights has ne0 %ld ne1 %ld ne2 %ld ne3 %ld\n", w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
+ggml_tensor * w = get_alignment_heads_QKs(gctx, state, params, n_audio_tokens);

 // Normalize - in original OpenAI code, this is done over dim=-2. In this case,
 // we already permuted N_TOKENS dimension to rows on last loop, becase ggml_norm
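The N_TOP_MOST branch above enumerates heads starting from the last (top-most) text layer and walks downward. A standalone worked example of that index mapping; the model dimensions below are assumptions for illustration, not from the commit:

    #include <stdio.h>

    // Illustrates the k -> (text_layer, head) mapping used by the new
    // WHISPER_AHEADS_N_TOP_MOST branch of get_alignment_heads_QKs.
    int main(void) {
        const int n_text_layers   = 4; // assumed model dimension
        const int heads_per_layer = 6; // assumed model dimension
        const int n_top_most      = 2; // corresponds to params.dtw_n_top_most.n
        const int n_heads         = heads_per_layer * n_top_most;

        for (int k = 0; k < n_heads; ++k) {
            // Same arithmetic as the commit: k/heads_per_layer counts how far
            // down from the last layer we are; k%heads_per_layer picks the head.
            const int text_layer = n_text_layers - (k / heads_per_layer) - 1;
            const int head       = k % heads_per_layer;
            printf("k=%2d -> text_layer=%d head=%d\n", k, text_layer, head);
        }
        return 0; // k=0..5 hit layer 3 (the top-most), k=6..11 hit layer 2
    }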
2 changes: 1 addition & 1 deletion whisper.h
@@ -151,7 +151,7 @@ extern "C" {

 enum whisper_alignment_heads_preset {
 WHISPER_AHEADS_NONE,
-WHISPER_AHEADS_N_TOP_MOST,
+WHISPER_AHEADS_N_TOP_MOST, // All heads from the N-top-most text-layers
 WHISPER_AHEADS_CUSTOM,
 WHISPER_AHEADS_TINY_EN,
 WHISPER_AHEADS_TINY,
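For WHISPER_AHEADS_CUSTOM, whisper.cpp reads params.dtw_custom.aheads and, per head, heads[k].n_text_layer and heads[k].n_head. A hedged sketch of supplying a custom set; the whisper_ahead / whisper_aheads type names and layout below are assumptions inferred from those field accesses and are not shown in this diff:

    // Assumed type layout -- check whisper.h for the real definitions:
    typedef struct whisper_ahead  { int n_text_layer; int n_head; } whisper_ahead;
    typedef struct whisper_aheads { size_t n_heads; const whisper_ahead * heads; } whisper_aheads;

    // Two hand-picked (text layer, head) pairs -- values illustrative only.
    static const whisper_ahead k_custom_heads[] = { { 2, 0 }, { 3, 1 } };

    struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    params.dtw_token_timestamps = true;
    params.dtw_ah_preset        = WHISPER_AHEADS_CUSTOM;
    params.dtw_custom.aheads    = (whisper_aheads) { 2, k_custom_heads };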
