Skip to content

Commit

Permalink
llama : sanity checks for access to logits (#4274)
Browse files Browse the repository at this point in the history
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
  • Loading branch information
cebtenzzre and ggerganov authored Dec 16, 2023
1 parent 88ae895 commit 8a5be3b
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1505,6 +1505,10 @@ struct llama_context {

// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
#ifndef NDEBUG
// guard against access to unset logits
std::vector<bool> logits_valid;
#endif
bool logits_all = false;

// input embedding (1-dimensional array: [n_embd])
Expand Down Expand Up @@ -6150,20 +6154,37 @@ static int llama_decode_internal(
{
auto & logits_out = lctx.logits;

#ifndef NDEBUG
auto & logits_valid = lctx.logits_valid;
logits_valid.clear();
logits_valid.resize(n_tokens);

logits_out.clear();
#endif

if (batch.logits) {
logits_out.resize(n_vocab * n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
if (batch.logits[i] == 0) {
continue;
}
memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab);
#ifndef NDEBUG
logits_valid[i] = true;
#endif
}
} else if (lctx.logits_all) {
logits_out.resize(n_vocab * n_tokens);
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
#ifndef NDEBUG
std::fill(logits_valid.begin(), logits_valid.end(), true);
#endif
} else {
logits_out.resize(n_vocab);
memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab);
#ifndef NDEBUG
logits_valid[n_tokens - 1] = true;
#endif
}
}

Expand Down Expand Up @@ -10052,6 +10073,7 @@ float * llama_get_logits(struct llama_context * ctx) {
}

// Return a pointer to the logits for token index `i`, i.e. the start of
// ctx->logits offset by i * n_vocab floats.
// When assertions are enabled (NDEBUG not defined), first verifies that
// ctx->logits_valid marks index `i` as set; `.at()` additionally bounds-checks `i`.
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
    assert(ctx->logits_valid.at(i));

    const auto row_len = ctx->model.hparams.n_vocab;
    float * const base = ctx->logits.data();
    return base + i*row_len;
}

Expand Down

0 comments on commit 8a5be3b

Please sign in to comment.