diff --git a/src/llama.cpp b/src/llama.cpp index 8e4e3137e5e41..6838a6fc7d712 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -12687,12 +12687,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { float * data = (float *) lctx.inp_KQ_mask->data; float * data_swa = nullptr; + const llama_pos n_keep_swa = hparams.n_ctx_swa - batch.n_tokens; if (lctx.model.arch == LLM_ARCH_GEMMA2) { GGML_ASSERT(!lctx.inp_KQ_mask_l.empty() && "gemma 2 requires different KQ mask per layer"); GGML_ASSERT(hparams.n_ctx_swa > 0); data_swa = (float *) lctx.inp_KQ_mask_l[0]->data; data = (float *) lctx.inp_KQ_mask_l[1]->data; + // because layer masks are alternate for gemma 2, we only need to take first 2 layers } // For causal attention, use only the previous KV cells @@ -12717,9 +12719,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { data[h*(n_kv*n_tokens) + j*n_kv + i] = f; // may need to cut off old tokens for sliding window - if (data_swa && f != -INFINITY) { - const llama_pos n_keep = hparams.n_ctx_swa - batch.n_tokens; - if (pos - lctx.kv_self.cells[i].pos > n_keep) { + if (data_swa) { + if (pos - lctx.kv_self.cells[i].pos > n_keep_swa) { f = -INFINITY; } data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;