llama : use a vector for ctx->output_ids
* llama : rework reallocation logic for llama_output_reserve

Now comparing the actual size with the new total size of the output buffer
to allow more efficient enabling and disabling of the embeddings
and/or logits output in the future.
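In rough terms, the reworked logic computes the byte size that the requested logits/embeddings layout needs and reallocates the backing buffer only when that exceeds the size of the existing allocation, rather than comparing token-count capacities. Below is a minimal, self-contained C++ sketch of that policy (hypothetical names such as output_buffer and reserve; a std::vector stands in for the ggml backend buffer, so this illustrates the idea rather than the actual llama.cpp code):

// Sketch only: grow-only reallocation keyed on the required byte size.
#include <cstdio>
#include <vector>

struct output_buffer {
    std::vector<float> buf;   // stand-in for the backend output buffer
    size_t logits_size = 0;   // floats currently reserved for logits
    size_t embd_size   = 0;   // floats currently reserved for embeddings

    void reserve(size_t n_outputs, size_t n_vocab, size_t n_embd, bool has_logits, bool has_embd) {
        const size_t new_logits = has_logits ? n_vocab*n_outputs : 0;
        const size_t new_embd   = has_embd   ? n_embd *n_outputs : 0;

        const size_t prev_size = buf.size();            // current capacity, in floats
        const size_t new_size  = new_logits + new_embd; // required capacity, in floats

        // reallocate only when more than the current capacity is required
        if (prev_size < new_size) {
            std::printf("growing output buffer: %zu -> %zu floats\n", prev_size, new_size);
            buf.resize(new_size);
        }

        logits_size = new_logits;
        embd_size   = new_embd;
    }
};

int main() {
    output_buffer out;
    out.reserve(512, 32000, 4096, true, false); // first call allocates
    out.reserve(512, 32000, 4096, true, false); // same request: no reallocation
    out.reserve(  8, 32000, 4096, true, true);  // smaller request: buffer is reused as-is
    return 0;
}

Because only growth triggers a reallocation, toggling logits and/or embeddings output between calls reuses the existing buffer whenever it is already large enough.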
compilade committed Mar 19, 2024
1 parent 09bb15a commit 4551e7e
Showing 1 changed file with 32 additions and 34 deletions.
llama.cpp: 66 changes (32 additions, 34 deletions)
@@ -2055,8 +2055,6 @@ struct llama_context {
             ggml_backend_free(backend);
         }
 
-        free(output_ids);
-
 #ifdef GGML_USE_VULKAN
         ggml_vk_free_cpu_assist();
 #endif
@@ -2098,19 +2096,19 @@ struct llama_context {
     ggml_backend_buffer_t buf_output = nullptr;
 
     // decode output (2-dimensional array: [n_outputs][n_vocab])
-    size_t logits_size = 0; // capacity (of floats) for logits
-    float * logits = nullptr;
+    size_t  logits_size = 0; // capacity (of floats) for logits
+    float * logits      = nullptr;
 
-    int32_t * output_ids = nullptr; // map token positions to ids of the logits and embd buffers
-    size_t output_size = 0; // capacity (of tokens positions) for the output buffers
-    int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch
+    std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
+    size_t  output_size = 0; // capacity (of tokens positions) for the output buffers
+    int32_t n_outputs   = 0; // number of actually-used outputs in the current ubatch
 
     bool logits_all = false;
 
     // embeddings output (2-dimensional array: [n_outputs][n_embd])
     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
-    size_t embd_size = 0; // capacity (of floats) for embeddings
-    float * embd = nullptr;
+    size_t  embd_size = 0; // capacity (of floats) for embeddings
+    float * embd      = nullptr;
 
     // sequence embeddings output (map of [n_embd] vectors)
     // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
@@ -9179,51 +9177,51 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
     const auto n_batch = cparams.n_batch;
     const auto n_vocab = hparams.n_vocab;
     const auto n_embd = hparams.n_embd;
-    const int64_t capacity = lctx.output_size;
 
     // TODO: use a per-batch flag for logits presence instead
     const bool has_logits = cparams.causal_attn;
     const bool has_embd = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
 
-    if (!lctx.output_ids) {
-        // never resized afterwards
-        lctx.output_ids = (int32_t *) malloc(n_batch*sizeof(int32_t));
-        if (lctx.output_ids == nullptr) {
-            throw std::runtime_error("failed to allocate output_ids buffer");
-        }
+    const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
+    const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
+
+    if (lctx.output_ids.empty()) {
+        // init, never resized afterwards
+        lctx.output_ids.resize(n_batch);
     }
-    // alloc only when more than the current logits capacity is required
-    if (capacity < n_outputs_max) {
-        lctx.output_size = n_outputs_max;
-        lctx.logits_size = has_logits ? n_vocab*n_outputs_max : 0;
-        lctx.embd_size = has_embd ? n_embd*n_outputs_max : 0;
 
-        const size_t buf_output_size = (lctx.logits_size + lctx.embd_size)*sizeof(float);
+    const size_t prev_size = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output) : 0;
+    const size_t new_size = (logits_size + embd_size) * sizeof(float);
 
+    // alloc only when more than the current capacity is required
+    // TODO: also consider shrinking the buffer
+    if (prev_size < new_size) {
         if (lctx.buf_output) {
 #ifndef NDEBUG
             // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
-            const size_t prev_size = ggml_backend_buffer_get_size(lctx.buf_output);
-            LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, buf_output_size / 1024.0 / 1024.0);
+            LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
 #endif
             ggml_backend_buffer_free(lctx.buf_output);
             lctx.buf_output = nullptr;
             lctx.logits = nullptr;
             lctx.embd = nullptr;
         }
 
-        lctx.buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buf_output_size);
+        lctx.buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), new_size);
         if (lctx.buf_output == nullptr) {
-            throw std::runtime_error(format("failed to allocate output buffer of size %.2f MiB", buf_output_size / (1024.0 * 1024.0)));
+            throw std::runtime_error(format("failed to allocate output buffer of size %.2f MiB", new_size / (1024.0 * 1024.0)));
         }
-        float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output);
-
-        lctx.logits = has_logits ? output_base : nullptr;
-        lctx.embd = has_embd ? output_base + lctx.logits_size : nullptr;
     }
-    // set all ids as invalid (assume two's complement negative numbers)
-    memset(lctx.output_ids, -1, n_batch*sizeof(int32_t));
+    float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output);
+    lctx.output_size = n_outputs_max;
+    lctx.logits = has_logits ? output_base : nullptr;
+    lctx.embd = has_embd ? output_base + logits_size : nullptr;
+    lctx.logits_size = logits_size;
+    lctx.embd_size = embd_size;
+
+    // set all ids as invalid (negative)
+    std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1);
 
     ggml_backend_buffer_clear(lctx.buf_output, 0);
 
@@ -14151,8 +14149,8 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data
     // copy output ids
     {
         std::vector<int32_t> output_pos;
-        const size_t n_batch = ctx->cparams.n_batch;
-        const int32_t * output_ids = ctx->output_ids;
+        const size_t    n_batch = ctx->cparams.n_batch;
+        const auto & output_ids = ctx->output_ids;
 
         output_pos.resize(ctx->output_size);
 
