Skip to content

Commit

Permalink
main : add dtw (cont)
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Mar 20, 2024
1 parent 59c133a commit b55925d
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 37 deletions.
48 changes: 34 additions & 14 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@ void replace_all(std::string & s, const std::string & search, const std::string

// command-line parameters
struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_processors = 1;
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
int32_t duration_ms = 0;
int32_t progress_step = 5;
int32_t max_context = -1;
int32_t max_len = 0;
int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
int32_t audio_ctx = 0;
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
int32_t n_processors = 1;
int32_t offset_t_ms = 0;
int32_t offset_n = 0;
int32_t duration_ms = 0;
int32_t progress_step = 5;
int32_t max_context = -1;
int32_t max_len = 0;
int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
int32_t audio_ctx = 0;

float word_thold = 0.01f;
float entropy_thold = 2.40f;
Expand Down Expand Up @@ -76,6 +76,8 @@ struct whisper_params {

std::string openvino_encode_device = "CPU";

std::string dtw = "";

std::vector<std::string> fname_inp = {};
std::vector<std::string> fname_out = {};
};
Expand Down Expand Up @@ -149,6 +151,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else {
Expand Down Expand Up @@ -208,6 +211,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", "");
fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str());
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
fprintf(stderr, "\n");
Expand Down Expand Up @@ -890,10 +894,26 @@ int main(int argc, char ** argv) {
struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;

// TODO: expose these parameters to the command-line
if (false) {
if (!params.dtw.empty()) {
cparams.dtw_token_timestamps = true;
cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE_EN; // Match correctly with the model you are using.
cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;

if (params.dtw == "tiny") cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY;
if (params.dtw == "tiny.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN;
if (params.dtw == "base") cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE;
if (params.dtw == "base.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE_EN;
if (params.dtw == "small") cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL;
if (params.dtw == "small.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL_EN;
if (params.dtw == "medium") cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM;
if (params.dtw == "medium.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM_EN;
if (params.dtw == "large.v1") cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V1;
if (params.dtw == "large.v2") cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V2;
if (params.dtw == "large.v3") cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3;

if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str());
return 3;
}
}

struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
Expand Down
26 changes: 11 additions & 15 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1083,19 +1083,19 @@ static bool aheads_masks_init(
WHISPER_LOG_ERROR("%s: dtw_aheads_preset should be != DTW_AHEADS_NONE\n", __func__);
return false;
} else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) {
if (cparams.dtw_n_top_most.n > n_text_layer || cparams.dtw_n_top_most.n <= 0) {
WHISPER_LOG_ERROR("%s: dtw_n_top_most.n must be between %d and %d for this model.", __func__, 1, n_text_layer);
if (cparams.dtw_n_top > n_text_layer || cparams.dtw_n_top <= 0) {
WHISPER_LOG_ERROR("%s: dtw_n_top must be between %d and %d for this model.", __func__, 1, n_text_layer);
return false;
}
} else {
const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_custom.aheads : g_aheads.at(cparams.dtw_aheads_preset);
const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset);
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM) {
if (aheads.n_heads == 0) {
WHISPER_LOG_ERROR("%s: dtw_custom.aheads.n_heads should be > 0", __func__);
WHISPER_LOG_ERROR("%s: dtw_aheads.n_heads should be > 0", __func__);
return false;
}
if (aheads.heads == NULL) {
WHISPER_LOG_ERROR("%s: dtw_custom.aheads.heads unset", __func__);
WHISPER_LOG_ERROR("%s: dtw_aheads.heads unset", __func__);
return false;
}
}
Expand Down Expand Up @@ -3374,14 +3374,10 @@ struct whisper_context_params whisper_context_default_params() {

/*.dtw_token_timestamps =*/ false,
/*.dtw_aheads_preset =*/ WHISPER_AHEADS_NONE,
/*.dtw_n_top_most =*/ {
/*.n =*/ -1,
},
/*.dtw_custom =*/ {
/*.aheads =*/ {
/*.n_heads =*/ 0,
/*.heads =*/ NULL,
}
/*.dtw_n_top =*/ -1,
/*.dtw_aheads =*/ {
/*.n_heads =*/ 0,
/*.heads =*/ NULL,
},
/*.dtw_mem_size =*/ 1024*1024*128,
};
Expand Down Expand Up @@ -6861,13 +6857,13 @@ static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_
if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
return ret;
} else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) {
if (il >= n_text_layer - cparams.dtw_n_top_most.n) {
if (il >= n_text_layer - cparams.dtw_n_top) {
for (int32_t i = 0; i < n_head; ++i) {
ret.push_back(i);
}
}
} else {
const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_custom.aheads : g_aheads.at(cparams.dtw_aheads_preset);
const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset);
for (size_t i = 0; i < aheads.n_heads; ++i) {
if (aheads.heads[i].n_text_layer == il) {
ret.push_back(aheads.heads[i].n_head);
Expand Down
13 changes: 5 additions & 8 deletions whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,13 @@ extern "C" {
int gpu_device; // CUDA device

// [EXPERIMENTAL] Token-level timestamps with DTW
// FIXME: not sure if the way dtw_n_top_most and dtw_custom are structured is comfortable?
bool dtw_token_timestamps;
enum whisper_alignment_heads_preset dtw_aheads_preset;
struct {
int n;
} dtw_n_top_most;
struct {
whisper_aheads aheads;
} dtw_custom;
size_t dtw_mem_size;

int dtw_n_top;
struct whisper_aheads dtw_aheads;

size_t dtw_mem_size; // TODO: remove
};

typedef struct whisper_token_data {
Expand Down

0 comments on commit b55925d

Please sign in to comment.