Skip to content

Commit

Permalink
fix data_swa uninitialized
Browse files Browse the repository at this point in the history
  • Loading branch information
ngxson committed Jun 30, 2024
1 parent 7df7530 commit ab2c3de
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12687,12 +12687,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {

float * data = (float *) lctx.inp_KQ_mask->data;
float * data_swa = nullptr;
const llama_pos n_keep_swa = hparams.n_ctx_swa - batch.n_tokens;

if (lctx.model.arch == LLM_ARCH_GEMMA2) {
GGML_ASSERT(!lctx.inp_KQ_mask_l.empty() && "gemma 2 requires different KQ mask per layer");
GGML_ASSERT(hparams.n_ctx_swa > 0);
data_swa = (float *) lctx.inp_KQ_mask_l[0]->data;
data = (float *) lctx.inp_KQ_mask_l[1]->data;
// because layer masks are alternate for gemma 2, we only need to take first 2 layers
}

// For causal attention, use only the previous KV cells
Expand All @@ -12717,9 +12719,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;

// may need to cut off old tokens for sliding window
if (data_swa && f != -INFINITY) {
const llama_pos n_keep = hparams.n_ctx_swa - batch.n_tokens;
if (pos - lctx.kv_self.cells[i].pos > n_keep) {
if (data_swa) {
if (pos - lctx.kv_self.cells[i].pos > n_keep_swa) {
f = -INFINITY;
}
data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
Expand Down

0 comments on commit ab2c3de

Please sign in to comment.