whisper : full batched decoding support
ggerganov committed Nov 14, 2023
1 parent 8b943f9 commit 91096da
Showing 1 changed file with 62 additions and 37 deletions.
99 changes: 62 additions & 37 deletions whisper.cpp
@@ -688,6 +688,7 @@ struct whisper_decoder {
 // grammar parse state of generated sequence of tokens
 whisper_grammar grammar;

+int i_batch; // the index of the token in the current batch
 int seek_delta; // the window shift found so far based on the decoded timestamp tokens

 bool failed; // has the current segment failed to decode?
@@ -2228,7 +2229,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(

 ggml_allocr * alloc = wstate.alloc_decode.alloc;

-const int n_ctx = hparams.n_text_ctx;
+const int n_ctx = kv_self.size;
 const int n_state = hparams.n_text_state;
 const int n_head = hparams.n_text_head;
 const int n_layer = hparams.n_text_layer;
@@ -2569,7 +2570,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 // compute logits only for the last token
 // comment this line to compute logits for all n_tokens
 // might be useful in the future
-cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);
+//cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]);

 struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur);

@@ -2602,21 +2603,25 @@ static bool whisper_decode_internal(
 const auto & model = wctx.model;
 const auto & hparams = model.hparams;

-const int n_vocab = hparams.n_vocab;
+const int n_vocab  = hparams.n_vocab;
+const int n_tokens = batch.n_tokens;

 auto & logits_out = wstate.logits;

 struct ggml_tensor * logits;

-auto & kv_self = wstate.kv_self;
+// find KV slot for the batch
+{
+    auto & kv_self = wstate.kv_self;

-if (!whisper_kv_cache_find_slot(kv_self, batch)) {
-    return 1;
-}
+    if (!whisper_kv_cache_find_slot(kv_self, batch)) {
+        return false;
+    }

-kv_self.n = whisper_kv_cache_cell_max(kv_self);
-//kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
-//printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
+    kv_self.n = whisper_kv_cache_cell_max(kv_self);
+    //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self)));
+    //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]);
+}

 // decoder
 {
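The slot-finding functions called above are not shown in this diff. As a rough orientation only, here is a small, self-contained editorial sketch of what whisper_kv_cache_find_slot and whisper_kv_cache_cell_max are assumed to do, by analogy with the llama.cpp-style cache this code follows: claim a contiguous run of free cells for the batch tokens, and report how far into the cache the used cells extend so attention only has to cover that range. The toy_* names and the simple linear scan are illustrative, not the actual whisper.cpp implementation.

#include <cstdio>
#include <vector>

struct toy_cell { int pos = -1; };   // pos < 0 means the cell is free

// claim n_tokens consecutive free cells for positions pos0 .. pos0 + n_tokens - 1
static bool toy_find_slot(std::vector<toy_cell> & cache, int & head, int n_tokens, int pos0) {
    for (int start = 0; start + n_tokens <= (int) cache.size(); ++start) {
        bool free_run = true;
        for (int i = 0; i < n_tokens; ++i) {
            if (cache[start + i].pos >= 0) { free_run = false; break; }
        }
        if (free_run) {
            for (int i = 0; i < n_tokens; ++i) cache[start + i].pos = pos0 + i;
            head = start;
            return true;
        }
    }
    return false; // no room left - the caller treats this as a decode failure
}

// index one past the last used cell - attention only needs to span [0, cell_max)
static int toy_cell_max(const std::vector<toy_cell> & cache) {
    for (int i = (int) cache.size() - 1; i >= 0; --i) {
        if (cache[i].pos >= 0) return i + 1;
    }
    return 0;
}

int main() {
    std::vector<toy_cell> cache(8);
    int head = 0;
    if (!toy_find_slot(cache, head, /*n_tokens=*/3, /*pos0=*/0)) return 1;
    printf("slot head = %d, cell_max = %d\n", head, toy_cell_max(cache)); // slot head = 0, cell_max = 3
    return 0;
}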
@@ -2633,15 +2638,13 @@ static bool whisper_decode_internal(
     ggml_graph_compute_helper(wstate.backend, gf, n_threads);
 }

-// extract logits for all N tokens
-//logits_out.resize(n_tokens*n_vocab);
-//memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_tokens*n_vocab);
-//ggml_backend_tensor_get(logits, logits_out.data(), (n_vocab*(n_tokens - 1))*sizeof(float), sizeof(float)*n_vocab);
-
-// extract logits only for the last token
-logits_out.resize(n_vocab);
-//memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*n_vocab);
-ggml_backend_tensor_get(logits, logits_out.data(), 0, sizeof(float)*n_vocab);
+logits_out.resize(n_tokens*n_vocab);
+for (int i = 0; i < n_tokens; i++) {
+    if (batch.logits[i] == 0) {
+        continue;
+    }
+    ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab*i), sizeof(float)*(n_vocab*i), sizeof(float)*n_vocab);
+}

 if (batch.n_tokens > 1) {
     //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
@@ -3074,7 +3077,10 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {

 state->backend = whisper_backend_init(ctx->params);

-if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
+// TODO: determine how large the cache should be
+const int factor = 2;
+
+if (!kv_cache_init(ctx->model.hparams, state->kv_self, ctx->backend, ctx->itype, factor*ctx->model.hparams.n_text_ctx)) {
     WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
     delete state;
     return nullptr;
@@ -4566,7 +4572,7 @@ static void whisper_process_logits(
 auto & logprobs = decoder.logprobs;
 {
     logits.resize(n_logits);
-    memcpy(logits.data(), state.logits.data() + (state.logits.size() - n_logits), n_logits*sizeof(float));
+    memcpy(logits.data(), state.logits.data() + decoder.i_batch*n_logits, n_logits*sizeof(float));

     if (temperature > 0.0f) {
         for (int i = 0; i < n_logits; i++) {
@@ -5317,6 +5323,8 @@ int whisper_full_with_state(
 {
     const int64_t t_start_sample_us = ggml_time_us();

+    state->decoders[0].i_batch = prompt.size() - 1;
+
     whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);

     for (int j = 1; j < n_decoders_cur; ++j) {
@@ -5384,7 +5392,6 @@ int whisper_full_with_state(
 });

 uint32_t cur_c = 0;
-std::vector<int> decoder_idx(n_decoders_cur, -1);

 for (int j = 0; j < n_decoders_cur; ++j) {
     auto & decoder = state->decoders[j];
@@ -5408,8 +5415,6 @@ int whisper_full_with_state(
 decoder.sequence = cur.sequence;
 decoder.grammar = cur.grammar;

-decoder_idx[j] = cur.decoder_idx;
-
 whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1);

 WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n",
@@ -5535,32 +5540,52 @@ int whisper_full_with_state(
 state->t_sample_us += ggml_time_us() - t_start_sample_us;

 // obtain logits for the next token
-for (int j = 0; j < n_decoders_cur; ++j) {
-    auto & decoder = state->decoders[j];
-
-    if (decoder.failed || decoder.completed) {
-        continue;
-    }
+{
+    auto & batch = state->batch;

-    //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
+    batch.n_tokens = 0;

-    // TODO: use batch
     const int n_past = prompt.size() + i;

-    whisper_batch_prep_legacy(state->batch, &decoder.sequence.tokens.back().id, 1, n_past, j);
+    for (int j = 0; j < n_decoders_cur; ++j) {
+        auto & decoder = state->decoders[j];
+
+        if (decoder.failed || decoder.completed) {
+            continue;
+        }
+
+        //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta);
+
+        decoder.i_batch = batch.n_tokens;
+
+        batch.token   [batch.n_tokens]    = decoder.sequence.tokens.back().id;
+        batch.pos     [batch.n_tokens]    = n_past;
+        batch.n_seq_id[batch.n_tokens]    = 1;
+        batch.seq_id  [batch.n_tokens][0] = j;
+        batch.logits  [batch.n_tokens]    = 1;
+        batch.n_tokens++;
+    }
+
+    assert(batch.n_tokens > 0);

     if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, params.abort_callback, params.abort_callback_user_data)) {
         WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
         return -8;
     }

     {
-        const int64_t t_start_sample_us = ggml_time_us();
+        const int64_t t_start_sample_us = ggml_time_us();

-        whisper_process_logits(*ctx, *state, params, decoder, t_cur);
+        for (int j = 0; j < n_decoders_cur; ++j) {
+            auto & decoder = state->decoders[j];

+            if (decoder.failed || decoder.completed) {
+                continue;
+            }

-        state->t_sample_us += ggml_time_us() - t_start_sample_us;
+            whisper_process_logits(*ctx, *state, params, decoder, t_cur);
+        }

+        state->t_sample_us += ggml_time_us() - t_start_sample_us;
     }
 }
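Taken together, the hunk above replaces one decoder call per active decoder with a single batched call: each active decoder appends its last sampled token to state->batch (recording its row in i_batch, tagging the row with its own seq_id, and requesting logits for it), whisper_decode_internal runs once, and every decoder then processes its own logits row. A compact, self-contained editorial sketch of that gather / decode / read-back pattern, using toy stand-ins for whisper_batch and whisper_decoder (names and values here are illustrative, not whisper.cpp API):

#include <cstdio>
#include <vector>

struct toy_decoder {
    bool completed  = false;
    bool failed     = false;
    int  last_token = 0;
    int  i_batch    = -1;   // row of this decoder's token in the shared batch
};

struct toy_batch {
    std::vector<int> token;    // one entry per batch row
    std::vector<int> seq_id;   // which decoder (sequence) each row belongs to
    int n_tokens = 0;
};

int main() {
    const int n_vocab = 4;

    std::vector<toy_decoder> decoders(3);
    decoders[0].last_token = 7;
    decoders[1].completed  = true;   // finished decoders are skipped
    decoders[2].last_token = 9;

    // gather: one token per active decoder
    toy_batch batch;
    for (int j = 0; j < (int) decoders.size(); ++j) {
        auto & d = decoders[j];
        if (d.failed || d.completed) continue;
        d.i_batch = batch.n_tokens;
        batch.token.push_back(d.last_token);
        batch.seq_id.push_back(j);
        batch.n_tokens++;
    }

    // decode: stand-in for the single whisper_decode_internal() call,
    // producing one row of n_vocab logits per batch token
    std::vector<float> logits(batch.n_tokens*n_vocab, 0.0f);
    for (int i = 0; i < batch.n_tokens; ++i) {
        logits[i*n_vocab] = (float) batch.token[i];   // fake values, enough to show the indexing
    }

    // read back: each active decoder picks up its own row via i_batch
    for (int j = 0; j < (int) decoders.size(); ++j) {
        const auto & d = decoders[j];
        if (d.failed || d.completed) continue;
        printf("decoder %d -> row %d, first logit %.1f\n", j, d.i_batch, logits[d.i_batch*n_vocab]);
    }
    return 0;
}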
