Skip to content

Commit

Permalink
Fix issues related to changes in whisper.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
denersc committed Jan 11, 2024
1 parent d6f4d7a commit e99b9e5
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions whisper.cpp
Original file line number Diff line number Diff line change
@@ -6685,7 +6685,7 @@ static void whisper_exp_compute_token_level_timestamps(
// based on
// https://github.com/openai/whisper/blob/main/whisper/timing.py#L83
static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
- WHISPER_ASSERT(x->n_dims == 2);
+ WHISPER_ASSERT(ggml_n_dims(x) == 2);

int64_t N = x->ne[0];
int64_t M = x->ne[1];
@@ -6773,7 +6773,7 @@ static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) {
static ggml_tensor * median_filter(ggml_context * ctx, ggml_tensor * x, int filter_width) {
WHISPER_ASSERT(filter_width < x->ne[2]);
WHISPER_ASSERT(filter_width % 2);
- WHISPER_ASSERT(x->n_dims == 3);
+ WHISPER_ASSERT(ggml_n_dims(x) == 3);
WHISPER_ASSERT(x->type == GGML_TYPE_F32);

std::vector<float> filter;
@@ -6918,7 +6918,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
// operation (after median filter)
// IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims
// OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
- w = ggml_norm(gctx, w, 0);
+ w = ggml_norm(gctx, w, 1e-9);
w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0 ,3), 0, 2, 1, 3);
struct ggml_cgraph * gf = ggml_new_graph(gctx);
ggml_build_forward_expand(gf, w);
@@ -6933,9 +6933,7 @@ static void whisper_exp_compute_token_level_timestamps_dtw(
// IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims
// OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS dims
w = ggml_mean(gctx, w);
- ggml_tensor * scale = ggml_new_tensor_1d(gctx, GGML_TYPE_F32, 1);
- ggml_set_f32_1d(scale, 0, -1);
- w = ggml_scale(gctx, w, scale);
+ w = ggml_scale(gctx, w, -1.0);
w = ggml_reshape_2d(gctx, w, w->ne[1], w->ne[2]);
struct ggml_cgraph * gf2 = ggml_new_graph(gctx);
ggml_build_forward_expand(gf2, w);
Expand Down

0 comments on commit e99b9e5

Please sign in to comment.