llama : rename batch.logits to batch.output
This commit renames the `logits` field of the `llama_batch` struct to
`output`.

The motivation for this change (apart from the TODO comment) is that
the `logits` field is actually used to specify that output should be
generated. For example, in the case of generating embeddings, setting
`logits` to true can be confusing since the logits are not used when
generating embeddings.
danbev committed Feb 5, 2025
1 parent 9f4cc8f commit 291a785
Showing 19 changed files with 52 additions and 53 deletions.
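
For context, the renamed `output` field is the per-token flag that tells `llama_decode` which positions must produce output (logits or embeddings). Below is a minimal sketch of the usual prompt-evaluation pattern against the post-rename API; `ctx` and the tokenized prompt `tokens` are assumed to be set up elsewhere, and `common_batch_add` is the helper changed in the first file below.

    llama_batch batch = llama_batch_init(512, 0, 1);

    // queue the prompt on sequence 0; intermediate tokens need no output
    for (size_t i = 0; i < tokens.size(); i++) {
        common_batch_add(batch, tokens[i], (llama_pos) i, { 0 }, false);
    }

    // request output only for the last prompt token
    batch.output[batch.n_tokens - 1] = true;

    if (llama_decode(ctx, batch) != 0) {
        // handle decode failure
    }

    const float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);

    llama_batch_free(batch);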
6 changes: 3 additions & 3 deletions common/common.cpp
@@ -607,7 +607,7 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
<< ", pos " << std::to_string(batch.pos[i])
<< ", n_seq_id " << std::to_string(batch.n_seq_id[i])
<< ", seq_id " << std::to_string(batch.seq_id[i][0])
<< ", logits " << std::to_string(batch.logits[i]);
<< ", output " << std::to_string(batch.output[i]);
}

buf << " ]";
@@ -1617,7 +1617,7 @@ void common_batch_add(
llama_token id,
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
- bool logits) {
+ bool output) {
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");

batch.token [batch.n_tokens] = id;
@@ -1626,7 +1626,7 @@ void common_batch_add(
for (size_t i = 0; i < seq_ids.size(); ++i) {
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
}
- batch.logits [batch.n_tokens] = logits;
+ batch.output [batch.n_tokens] = output;

batch.n_tokens++;
}
4 changes: 2 additions & 2 deletions examples/batched-bench/batched-bench.cpp
@@ -73,7 +73,7 @@ int main(int argc, char ** argv) {
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
- batch.logits + i,
+ batch.output + i,
};

const int ret = llama_decode(ctx, batch_view);
@@ -128,7 +128,7 @@ int main(int argc, char ** argv) {
common_batch_add(batch, 0, i, { j }, false);
}
}
- batch.logits[batch.n_tokens - 1] = true;
+ batch.output[batch.n_tokens - 1] = true;

const auto t_pp_start = ggml_time_us();

6 changes: 3 additions & 3 deletions examples/batched.swift/Sources/main.swift
@@ -104,11 +104,11 @@ for (i, token) in tokens.enumerated() {
if let seq_id = batch.seq_id[i] {
seq_id[0] = 0
}
- batch.logits[i] = 0
+ batch.output[i] = 0
}

// llama_decode will output logits only for the last token of the prompt
- batch.logits[Int(batch.n_tokens) - 1] = 1
+ batch.output[Int(batch.n_tokens) - 1] = 1

if llama_decode(context, batch) != 0 {
print("llama_decode() failed")
@@ -171,7 +171,7 @@ while n_cur <= n_len {
if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
seq_id[0] = Int32(i)
}
- batch.logits[Int(batch.n_tokens)] = 1
+ batch.output[Int(batch.n_tokens)] = 1

i_batch[i] = batch.n_tokens

2 changes: 1 addition & 1 deletion examples/batched/batched.cpp
@@ -131,7 +131,7 @@ int main(int argc, char ** argv) {
}

// llama_decode will output logits only for the last token of the prompt
- batch.logits[batch.n_tokens - 1] = true;
+ batch.output[batch.n_tokens - 1] = true;

if (llama_decode(ctx, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
2 changes: 1 addition & 1 deletion examples/embedding/embedding.cpp
@@ -54,7 +54,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}

for (int i = 0; i < batch.n_tokens; i++) {
- if (!batch.logits[i]) {
+ if (!batch.output[i]) {
continue;
}

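
This hunk is where the old name was most misleading: in the embedding path, `batch.output[i]` marks tokens whose embeddings should be read back, and no logits are involved at all. A sketch of that read-back loop, mirroring the code above (assumes `ctx` was created with embeddings enabled and `n_embd` was queried from the model):

    for (int i = 0; i < batch.n_tokens; i++) {
        if (!batch.output[i]) {
            continue; // token was not marked to produce output
        }
        // pooled per-sequence embedding if pooling is enabled,
        // otherwise fall back to the per-token embedding
        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
        if (embd == nullptr) {
            embd = llama_get_embeddings_ith(ctx, i);
        }
        // ... copy n_embd floats out of embd ...
    }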
6 changes: 3 additions & 3 deletions examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -193,7 +193,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
common_batch_add(*batch, 0, i, { 0 }, false);
}

- batch->logits[batch->n_tokens - 1] = true;
+ batch->output[batch->n_tokens - 1] = true;
llama_kv_cache_clear(context);

const auto t_pp_start = ggml_time_us();
@@ -297,7 +297,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
for (int i = 0; i < n_tokens; ++i) {
batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
}
- batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+ batch->output = (int8_t *) malloc(sizeof(int8_t) * n_tokens);

return reinterpret_cast<jlong>(batch);
}
@@ -381,7 +381,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
}

// llama_decode will output logits only for the last token of the prompt
- batch->logits[batch->n_tokens - 1] = true;
+ batch->output[batch->n_tokens - 1] = true;

if (llama_decode(context, *batch) != 0) {
LOGe("llama_decode() failed");
8 changes: 4 additions & 4 deletions examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -9,14 +9,14 @@ func llama_batch_clear(_ batch: inout llama_batch) {
batch.n_tokens = 0
}

- func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) {
+ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ outputs: Bool) {
batch.token [Int(batch.n_tokens)] = id
batch.pos [Int(batch.n_tokens)] = pos
batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
for i in 0..<seq_ids.count {
batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
}
- batch.logits [Int(batch.n_tokens)] = logits ? 1 : 0
+ batch.output [Int(batch.n_tokens)] = outputs ? 1 : 0

batch.n_tokens += 1
}
@@ -139,7 +139,7 @@ actor LlamaContext {
let i = Int(i1)
llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
}
- batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+ batch.output[Int(batch.n_tokens) - 1] = 1 // true

if llama_decode(context, batch) != 0 {
print("llama_decode() failed")
@@ -208,7 +208,7 @@ actor LlamaContext {
for i in 0..<n_tokens {
llama_batch_add(&batch, 0, Int32(i), [0], false)
}
- batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+ batch.output[Int(batch.n_tokens) - 1] = 1 // true

llama_kv_cache_clear(context)

8 changes: 4 additions & 4 deletions examples/llava/llava.cpp
@@ -441,13 +441,13 @@ struct llava_embd_batch {
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id> seq_id_0;
std::vector<llama_seq_id *> seq_ids;
- std::vector<int8_t> logits;
+ std::vector<int8_t> outputs;
llama_batch batch;
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
pos .resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids .resize(n_tokens + 1);
- logits .resize(n_tokens);
+ outputs .resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
@@ -458,13 +458,13 @@ struct llava_embd_batch {
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
- /*logits =*/ logits.data(),
+ /*output =*/ outputs.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = pos_0 + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
- batch.logits [i] = false;
+ batch.output [i] = false;
}
}
};
4 changes: 2 additions & 2 deletions examples/parallel/parallel.cpp
@@ -266,7 +266,7 @@ int main(int argc, char ** argv) {

// extract the logits only for the last token
if (batch.n_tokens > 0) {
- batch.logits[batch.n_tokens - 1] = true;
+ batch.output[batch.n_tokens - 1] = true;
}

client.n_prompt = tokens_prompt.size();
@@ -309,7 +309,7 @@ int main(int argc, char ** argv) {
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
- batch.logits + i,
+ batch.output + i,
};

const int ret = llama_decode(ctx, batch_view);
4 changes: 2 additions & 2 deletions examples/passkey/passkey.cpp
@@ -146,7 +146,7 @@ int main(int argc, char ** argv) {
}

if (i + n_batch >= n_tokens_all) {
- batch.logits[batch.n_tokens - 1] = true;
+ batch.output[batch.n_tokens - 1] = true;
}

if (llama_decode(ctx, batch) != 0) {
@@ -180,7 +180,7 @@ int main(int argc, char ** argv) {
}

if (i + n_batch >= n_tokens_all) {
- batch.logits[batch.n_tokens - 1] = true;
+ batch.output[batch.n_tokens - 1] = true;
}

if (llama_decode(ctx, batch) != 0) {
14 changes: 7 additions & 7 deletions examples/perplexity/perplexity.cpp
@@ -572,9 +572,9 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
batch.pos [idx] = j*n_batch + k;
batch.n_seq_id[idx] = 1;
batch.seq_id [idx][0] = seq;
- batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0;
+ batch.output [idx] = batch.pos[idx] >= first ? 1 : 0;

- n_outputs += batch.logits[idx] != 0;
+ n_outputs += batch.output[idx] != 0;
}
batch.n_tokens += batch_size;

@@ -669,7 +669,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
- batch.logits + i,
+ batch.output + i,
};

const int ret = llama_decode(ctx, batch_view);
@@ -680,7 +680,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<

int n_outputs = 0;
for (int i = 0; i < n_tokens; ++i) {
- n_outputs += batch_view.logits[i] != 0;
+ n_outputs += batch_view.output[i] != 0;
}

memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float));
@@ -896,7 +896,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
}
- batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+ batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
n_logits += 1;

for (int s = 0; s < 4; ++s) {
@@ -1177,7 +1177,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
for (size_t i = 0; i < data[i1].common_prefix; ++i) {
common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
}
- batch.logits[batch.n_tokens - 1] = true;
+ batch.output[batch.n_tokens - 1] = true;
n_logits += 1;

for (int s = 0; s < 2; ++s) {
@@ -1545,7 +1545,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
//llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
}
- batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+ batch.output[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
n_logits += 1;

for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
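
perplexity.cpp is the one caller above that marks many outputs per batch rather than just the final token: every position at or beyond `first` is flagged, and the packed logits are copied out in bulk afterwards. A condensed sketch of that pattern (names follow the hunks above; token ids and positions are assumed to be filled in already):

    // flag every token from position `first` onward as an output
    int n_outputs = 0;
    for (int idx = 0; idx < batch.n_tokens; ++idx) {
        batch.output[idx] = batch.pos[idx] >= first ? 1 : 0;
        n_outputs += batch.output[idx] != 0;
    }

    // after a successful llama_decode(), llama_get_logits(ctx) holds
    // n_outputs rows of n_vocab logits, packed in batch order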
2 changes: 1 addition & 1 deletion examples/retrieval/retrieval.cpp
@@ -92,7 +92,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}

for (int i = 0; i < batch.n_tokens; i++) {
- if (!batch.logits[i]) {
+ if (!batch.output[i]) {
continue;
}

2 changes: 1 addition & 1 deletion examples/save-load-state/save-load-state.cpp
@@ -52,7 +52,7 @@ int main(int argc, char ** argv) {
for (size_t i = 0; i < tokens.size(); i++) {
common_batch_add(batch, tokens[i], i, {0}, false);
}
- batch.logits[batch.n_tokens - 1] = true; // generate next token
+ batch.output[batch.n_tokens - 1] = true; // generate next token

// evaluate prompt
llama_decode(ctx, batch);
8 changes: 4 additions & 4 deletions examples/server/server.cpp
@@ -2413,7 +2413,7 @@ struct server_context {
std::vector<float> embd_res(n_embd, 0.0f);

for (int i = 0; i < batch.n_tokens; ++i) {
- if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+ if (!batch.output[i] || batch.seq_id[i][0] != slot.id) {
continue;
}

@@ -2451,7 +2451,7 @@ struct server_context {
res->n_tokens = slot.n_prompt_tokens;

for (int i = 0; i < batch.n_tokens; ++i) {
- if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+ if (!batch.output[i] || batch.seq_id[i][0] != slot.id) {
continue;
}

@@ -3109,7 +3109,7 @@ struct server_context {
}

// extract the logits only for the last token
- batch.logits[batch.n_tokens - 1] = true;
+ batch.output[batch.n_tokens - 1] = true;

slot.n_decoded = 0;
slot.i_batch = batch.n_tokens - 1;
@@ -3149,7 +3149,7 @@ struct server_context {
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
- batch.logits + i,
+ batch.output + i,
};

const int ret = llama_decode(ctx, batch_view);
2 changes: 1 addition & 1 deletion examples/tts/tts.cpp
@@ -722,7 +722,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size());

// llama_decode will output logits only for the last token of the prompt
- batch.logits[batch.n_tokens - 1] = true;
+ batch.output[batch.n_tokens - 1] = true;

if (llama_decode(ctx_ttc, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
2 changes: 1 addition & 1 deletion include/llama.h
@@ -252,7 +252,7 @@ extern "C" {
llama_pos * pos;
int32_t * n_seq_id;
llama_seq_id ** seq_id;
- int8_t * logits; // TODO: rename this to "output"
+ int8_t * output;
} llama_batch;

enum llama_model_kv_override_type {
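
The struct change above is the entire API surface of the rename: `llama_batch` is a set of parallel per-token arrays, and `output` is the int8 flag array among them. That layout is what makes the batch-view pattern in several files above work. A window of `n_tokens` tokens starting at offset `i` is described without copying by offsetting every pointer, including the renamed one:

    llama_batch batch_view = {
        n_tokens,
        batch.token    + i,
        nullptr,            // embd: unused when token ids are given
        batch.pos      + i,
        batch.n_seq_id + i,
        batch.seq_id   + i,
        batch.output   + i,
    };
    const int ret = llama_decode(ctx, batch_view);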
18 changes: 9 additions & 9 deletions src/llama-batch.cpp
@@ -102,17 +102,17 @@ void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & s
ubatch.output[ubatch.n_tokens + i] = 1;
out_ids.push_back(ids[seq.offset + i]);
}
- } else if (batch->logits) {
+ } else if (batch->output) {
if (ubatch.equal_seqs) {
for (size_t i = 0; i < length; ++i) {
size_t id = ids[seq.offset + i];
- int8_t is_output = batch->logits[id];
+ int8_t is_output = batch->output[id];
ubatch.output[ubatch.n_tokens + i] = is_output;
if (is_output) { out_ids.push_back(id); }
}
} else {
// simple split
- ubatch.output = batch->logits + seq.offset;
+ ubatch.output = batch->output + seq.offset;
for (size_t i = 0; i < length; ++i) {
if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
}
@@ -298,10 +298,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
}
batch.seq_id = seq_id.data();
}
- if (!batch.logits) {
- logits.resize(batch.n_tokens);
- logits[logits.size() - 1] = true;
- batch.logits = logits.data();
+ if (!batch.output) {
+ outputs.resize(batch.n_tokens);
+ outputs[outputs.size() - 1] = true;
+ batch.output = outputs.data();
}
}

@@ -348,7 +348,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
}
batch.seq_id[n_tokens_alloc] = nullptr;

- batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);
+ batch.output = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc);

return batch;
}
@@ -364,5 +364,5 @@ void llama_batch_free(struct llama_batch batch) {
}
free(batch.seq_id);
}
- if (batch.logits) free(batch.logits);
+ if (batch.output) free(batch.output);
}
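
One behavioral detail worth noting in the `llama_batch_allocr` hunk above: when a caller leaves `batch.output` null, the allocator synthesizes an array with only the last token flagged. That is why convenience callers that never touch the field still get logits for the final token. A hedged sketch (assumes `ctx` and a `std::vector<llama_token> tokens` are set up):

    // llama_batch_get_one() leaves `output` null, so only the last
    // token's logits are computed by default
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());
    if (llama_decode(ctx, batch) == 0) {
        const float * logits = llama_get_logits_ith(ctx, -1); // last output
    }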