Skip to content

Commit

Permalink
fix bug triggered by -ml
Browse files Browse the repository at this point in the history
  • Loading branch information
bobqianic authored Oct 22, 2023
1 parent 15c74d2 commit 16bb889
Showing 1 changed file with 69 additions and 78 deletions.
147 changes: 69 additions & 78 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

// Ordered list of utf8_token entries (utf8_token is declared outside this view).
// This is the return type of whisper_merge_tokens.
typedef std::vector<utf8_token> whisper_merged_tokens;

// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
// Lowest is red, middle is yellow, highest is green.
const std::vector<std::string> k_colors = {
Expand Down Expand Up @@ -268,41 +270,48 @@ void whisper_print_progress_callback(struct whisper_context * /*ctx*/, struct wh
}
}

// NOTE(review): this span reads like a diff rendered without +/- markers. The signature and
// body of whisper_print_colorized_token are interleaved with the signature and body of
// whisper_merge_tokens, so the braces do not pair up as written. The comments below describe
// each fragment on its own.
//
// Fragment: whisper_print_colorized_token prints one token with printf, wrapped in a k_colors
// entry and the "\033[0m" reset. Text that is not valid UTF-8 is accumulated in buf1/buf2
// and printed once buf1 becomes valid.
void whisper_print_colorized_token(utf8_buf & buf1, utf8_buf & buf2, const std::string & token_text, const float & token_p, const std::string & speaker) {
// calculate color index: cube the probability, scale by the number of colors,
// then clamp to [0, k_colors.size() - 1]
int color_idx = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(token_p, 3)*float(k_colors.size()))));
// Fragment: whisper_merge_tokens walks segments [s0, n_segments) of ctx and returns their
// tokens as utf8_token entries. Consecutive tokens whose text is not valid UTF-8 on its own
// are concatenated into a single pending entry (buf).
whisper_merged_tokens whisper_merge_tokens(struct whisper_context * ctx, const whisper_params & params, int s0, int n_segments) {
whisper_merged_tokens result;
utf8_token buf;

// if token is valid UTF-8 print it directly
if (utf8_is_valid(token_text)) {
printf("%s%s%s%s", speaker.c_str(), k_colors[color_idx].c_str(), token_text.c_str(), "\033[0m");
} else {
// split token into valid and invalid parts
auto result = utf8_split(token_text);
// if first part (invalid part) is non-empty, add it to buf1
if (!result[0].empty()) {
buf1.buffer += result[0];
buf1.p_sum += token_p;
buf1.token_c++;
}
// if third part (invalid part) is non-empty, add it to buf2
if (!result[2].empty()) {
buf2.buffer += result[2];
buf2.p_sum += token_p;
buf2.token_c++;
}
// if buf1 is valid UTF-8, print it and move buf2 to buf1
if (utf8_is_valid(buf1.buffer)) {
// calculate the color index using the average token probability (p_sum / token_c)
const int avg_color_idx = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(buf1.p_sum / static_cast<float>(buf1.token_c), 3)*float(k_colors.size()))));
printf("%s%s%s%s", speaker.c_str(), k_colors[avg_color_idx].c_str(), buf1.buffer.c_str(), "\033[0m");
buf1 = buf2;
buf2.clear();
}
// if second part (valid part) is non-empty, print it
if (!result[1].empty()) {
printf("%s%s%s%s", speaker.c_str(), k_colors[color_idx].c_str(), result[1].c_str(), "\033[0m");
// Loop through each token within the segments, merging any neighboring tokens that are incomplete
for (int i = s0; i < n_segments; i++) {
// segment-level timestamps; every entry built from this segment carries them
int64_t t0 = whisper_full_get_segment_t0(ctx, i);
int64_t t1 = whisper_full_get_segment_t1(ctx, i);
// stays true until the first token of this segment that is not skipped has been processed
bool start_of_seg = true;

for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
// unless special tokens were requested, skip ids at or above the end-of-text token id
if (!params.print_special) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
}

const char * token_text = whisper_full_get_token_text(ctx, i, j);
const float token_p = whisper_full_get_token_p (ctx, i, j);

// emit the pending merged entry once its accumulated text is valid UTF-8
// NOTE(review): this is the only place buf is pushed and it runs before the current
// token is handled, so a buf completed by the very last token is not in result when
// the function returns — confirm whether a final flush is needed.
// NOTE(review): if utf8_is_valid accepts an empty string, an empty buf (token_c == 0,
// assuming a zero default) would be pushed on each iteration — confirm against
// utf8_is_valid / utf8_token.
if (utf8_is_valid(buf.text)) {
result.push_back(buf);
buf.clear();
}

// a token that is valid on its own becomes its own entry
// (initializer order presumably text, p_sum, token_c, t0, t1, start_of_seg — verify against utf8_token)
// NOTE(review): this does not check whether buf still holds pending bytes, so such a
// token would be emitted ahead of them.
if (utf8_is_valid(token_text)) {
result.push_back({token_text, token_p, 1, t0, t1, start_of_seg});
} else {
// incomplete token: append its bytes and probability to the pending entry
buf.text += std::string(token_text);
buf.p_sum += token_p;
buf.token_c++;
// t0 comes from the first fragment, t1 is overwritten by every fragment
if (buf.token_c == 1) {buf.t0 = t0;}
buf.t1 = t1;
// mark the merged entry as a segment start only if its first fragment opened the segment
if (buf.token_c == 1 && start_of_seg) {
buf.start_of_seg = start_of_seg;
}
}
start_of_seg = false;
}
}
return result;
}

// Prints the newest n_new segments of ctx to stdout with printf.
// NOTE(review): this span reads like a diff rendered without +/- markers. The "Expand All"
// line below is a hunk header, so the declarations of params, pcmf32s and n_segments are not
// visible here, and a per-segment loop is interleaved with a per-merged-token loop. The
// braces do not pair up as written.
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
Expand All @@ -313,66 +322,48 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper

std::string speaker = "";

int64_t t0 = 0;
int64_t t1 = 0;

// print the last n_new segments
const int s0 = n_segments - n_new;

// leading newline before the very first segment
if (s0 == 0) {
printf("\n");
}
// merge tokens, ensuring each one is encoded in UTF-8 without any truncation
auto merged_tokens = whisper_merge_tokens(ctx, params, s0, n_segments);

// Fragment: loop over segment indices
for (int i = s0; i < n_segments; i++) {
if (!params.no_timestamps || params.diarize) {
t0 = whisper_full_get_segment_t0(ctx, i);
t1 = whisper_full_get_segment_t1(ctx, i);
}

if (!params.no_timestamps) {
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
}

if (params.diarize && pcmf32s.size() == 2) {
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
// print tokens to terminal
// Fragment: loop over merged entries; here i indexes merged_tokens, not segments
for (size_t i = 0; i < merged_tokens.size(); i++) {
// print headers at the beginning of each segment
if (merged_tokens[i].start_of_seg) {
if (!params.no_timestamps) {
printf("[%s --> %s] ", to_timestamp(merged_tokens[i].t0).c_str(), to_timestamp(merged_tokens[i].t1).c_str());
}
// speaker is estimated only when diarization is on and pcmf32s has exactly two entries
if (params.diarize && pcmf32s.size() == 2) {
speaker = estimate_diarization_speaker(pcmf32s, merged_tokens[i].t0, merged_tokens[i].t1);
}
printf("%s", speaker.c_str());
}

// print a single token
if (params.print_colors) {
utf8_buf buf1;
utf8_buf buf2;
// color index from the average probability (p_sum / token_c), cubed, scaled by the
// number of colors and clamped to [0, k_colors.size() - 1]
const int color_idx = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(merged_tokens[i].p_sum / static_cast<float>(merged_tokens[i].token_c), 3)*float(k_colors.size()))));
printf("%s%s%s", k_colors[color_idx].c_str(), merged_tokens[i].text.c_str(), "\033[0m");
} else {
printf("%s", merged_tokens[i].text.c_str());
}

for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
if (params.print_special == false) {
const whisper_token id = whisper_full_get_token_id(ctx, i, j);
if (id >= whisper_token_eot(ctx)) {
continue;
}
// print suffix at the end of each segment
// this branch runs only when a following entry starts a new segment, so it is not
// reached after the last entry of merged_tokens
if (i < merged_tokens.size() - 1 && merged_tokens[i + 1].start_of_seg) {
if (params.tinydiarize) {
// NOTE(review): in the merged-token loop i is an index into merged_tokens, while the
// per-segment loop passes a segment index to this call — confirm which is intended.
if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
printf("%s", params.tdrz_speaker_turn.c_str());
}

const char * token_text = whisper_full_get_token_text(ctx, i, j);
const float token_p = whisper_full_get_token_p (ctx, i, j);

whisper_print_colorized_token(buf1, buf2, token_text, token_p, speaker);

}
} else {
const char * text = whisper_full_get_segment_text(ctx, i);

printf("%s%s", speaker.c_str(), text);
}

if (params.tinydiarize) {
if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
printf("%s", params.tdrz_speaker_turn.c_str());
// with timestamps or speakers: each segment on new line
if (!params.no_timestamps || params.diarize) {
printf("\n");
}
}

// with timestamps or speakers: each segment on new line
if (!params.no_timestamps || params.diarize) {
printf("\n");
fflush(stdout);
}

fflush(stdout);
}
}

Expand Down

0 comments on commit 16bb889

Please sign in to comment.