From bf4cb4abad4e35c74b387df034cc4ac7b22e5fe6 Mon Sep 17 00:00:00 2001 From: mky_coder <47767389+mkycoder@users.noreply.github.com> Date: Tue, 18 Jun 2024 23:10:33 +0800 Subject: [PATCH] whisper : optimize fft() function (#2242) Co-authored-by: Mike Fan <60965742+mike-fzy@users.noreply.github.com> --- whisper.cpp | 56 +++++++++++++++++++++-------------------------------- 1 file changed, 22 insertions(+), 34 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 4b96e8bcb66..d10083ec97b 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2974,10 +2974,7 @@ whisper_span whisper_mel_calc::hann_window() { // naive Discrete Fourier Transform // input is real-valued // output is complex-valued -static void dft(const std::vector & in, std::vector & out) { - int N = in.size(); - - out.resize(N*2); +static void dft(const float* in, int N, float* out) { const int sin_cos_step = SIN_COS_N_COUNT / N; for (int k = 0; k < N; k++) { @@ -2999,44 +2996,35 @@ static void dft(const std::vector & in, std::vector & out) { // poor man's implementation - use something better // input is real-valued // output is complex-valued -static void fft(const std::vector & in, std::vector & out) { - out.resize(in.size()*2); - - int N = in.size(); - +static void fft(float* in, int N, float* out) { if (N == 1) { out[0] = in[0]; out[1] = 0; return; } - if (N%2 == 1) { - dft(in, out); + const int half_N = N / 2; + if (N - half_N*2 == 1) { + dft(in, N, out); return; } - std::vector even; - std::vector odd; - - even.reserve(N/2); - odd.reserve(N/2); - - for (int i = 0; i < N; i++) { - if (i % 2 == 0) { - even.push_back(in[i]); - } else { - odd.push_back(in[i]); - } + float* even = in + N; + for (int i = 0; i < half_N; ++i) { + even[i]= in[2*i]; } + float* even_fft = out + 2 * N; + fft(even, half_N, even_fft); - std::vector even_fft; - std::vector odd_fft; - - fft(even, even_fft); - fft(odd, odd_fft); + float* odd = even; + for (int i = 0; i < half_N; ++i) { + odd[i] = in[2*i + 1]; + } + float* odd_fft = even_fft + N; + fft(odd, half_N, odd_fft); const int sin_cos_step = SIN_COS_N_COUNT / N; - for (int k = 0; k < N/2; k++) { + for (int k = 0; k < half_N; k++) { int idx = k * sin_cos_step; // t = 2*M_PI*k/N float re = global_cache.cos_vals[idx]; // cos(t) float im = -global_cache.sin_vals[idx]; // sin(t) @@ -3047,8 +3035,8 @@ static void fft(const std::vector & in, std::vector & out) { out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd; out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd; - out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; - out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; + out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; + out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; } } @@ -3066,8 +3054,8 @@ void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::v const whisper_filters & filters, whisper_mel_data & mel) { const auto frame_size = WHISPER_N_FFT; const auto frame_step = WHISPER_HOP_LENGTH; - std::vector fft_in(frame_size, 0.0); - std::vector fft_out(2 * frame_size); + std::vector fft_in(frame_size * 2, 0.0); + std::vector fft_out(frame_size * 2 * 2 * 2); int n_fft = filters.n_fft; int i = ith; @@ -3088,7 +3076,7 @@ void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::v } // FFT - fft(fft_in, fft_out); + fft(fft_in.data(), frame_size, fft_out.data()); // Calculate modulus^2 of complex numbers // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.