apply logit_softcap to scale in kernel
JohannesGaessler committed Aug 24, 2024
Parent: 8618413 · Commit: 8043640
Showing 2 changed files with 8 additions and 3 deletions.
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda/fattn-common.cuh
@@ -666,6 +666,10 @@ void launch_fattn(
     memcpy(&max_bias,      (float *) KQV->op_params + 1, sizeof(float));
     memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
 
+    if (logit_softcap != 0.0f) {
+        scale /= logit_softcap;
+    }
+
     const uint32_t n_head      = Q->ne[2];
     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
 
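Note on the hunk above: launch_fattn now folds the softcap divisor into scale once on the host, so the device code only ever sees the pre-divided value. Assuming the usual Gemma-2-style softcap of cap * tanh(x / cap), the rewrite relies on the identity cap * tanh(scale * s / cap) == cap * tanh((scale / cap) * s). A minimal scalar sketch of that equivalence (plain C with made-up numbers, not the kernel code):

    #include <assert.h>
    #include <math.h>
    #include <stdio.h>

    // Scalar sketch of logit softcapping on an attention score:
    // the capped score is cap * tanh(scale * s / cap). Pre-dividing scale by cap
    // lets the kernel compute cap * tanh(scale' * s) with scale' = scale / cap.
    int main(void) {
        const float s     = 3.7f;   // hypothetical raw Q.K dot product
        const float scale = 0.125f; // hypothetical 1/sqrt(head_dim)
        const float cap   = 30.0f;  // hypothetical logit_softcap value

        const float reference = cap * tanhf(scale * s / cap);

        float scale_folded = scale;
        if (cap != 0.0f) {
            scale_folded /= cap;    // what launch_fattn now does once per launch
        }
        const float in_kernel = cap * tanhf(scale_folded * s);

        assert(fabsf(reference - in_kernel) < 1e-6f);
        printf("capped score: %f\n", in_kernel);
        return 0;
    }

Doing the division once per launch also keeps a divide out of the per-score work inside the kernel.
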
7 changes: 4 additions & 3 deletions ggml/src/ggml.c
@@ -7123,9 +7123,6 @@ struct ggml_tensor * ggml_flash_attn_ext(
     int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    if (logit_softcap != 0.0f) {
-        scale /= logit_softcap;
-    }
     float params[] = { scale, max_bias, logit_softcap };
     ggml_set_op_params(result, params, sizeof(params));
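
For context on the `(float *) ... op_params + 0/1/2` pattern used by the backends above and below: ggml_flash_attn_ext packs scale, max_bias, and logit_softcap as three consecutive floats into the op's parameter blob, and the consumers memcpy them back out by offset. A small self-contained sketch of that round trip (a plain byte buffer stands in for the real op_params storage, which is a fixed-size array inside the tensor struct):

    #include <assert.h>
    #include <stdint.h>
    #include <string.h>

    // Stand-in for a tensor's op_params blob; 16 bytes is enough for three floats here.
    static uint8_t op_params[16];

    int main(void) {
        // Producer side (ggml_flash_attn_ext): pack the three parameters in order.
        const float params[] = { 0.125f /*scale*/, 0.0f /*max_bias*/, 30.0f /*logit_softcap*/ };
        memcpy(op_params, params, sizeof(params));

        // Consumer side (launch_fattn / the CPU forward): read them back by index.
        float scale, max_bias, logit_softcap;
        memcpy(&scale,         (float *) op_params + 0, sizeof(float));
        memcpy(&max_bias,      (float *) op_params + 1, sizeof(float));
        memcpy(&logit_softcap, (float *) op_params + 2, sizeof(float));

        assert(scale == 0.125f && max_bias == 0.0f && logit_softcap == 30.0f);
        return 0;
    }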

@@ -15283,6 +15280,10 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));
     memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
 
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }
+
     const uint32_t n_head      = neq2;
     const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
 
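One more piece of context for the surrounding lines: n_head_log2 (the largest power of two not exceeding n_head) feeds the per-head ALiBi slope derived from max_bias. The exact formula is not part of this diff, so treat the sketch below as an assumption about the usual ALiBi scheme rather than a quote of the ggml code:

    #include <math.h>
    #include <stdint.h>
    #include <stdio.h>

    // Sketch of a per-head ALiBi slope derivation that n_head_log2 feeds into.
    // Assumption: the first n_head_log2 heads follow one geometric series and the
    // remaining heads are interleaved on a second, shallower series.
    static float alibi_slope(float max_bias, uint32_t n_head_log2, uint32_t h) {
        if (max_bias <= 0.0f) {
            return 1.0f; // ALiBi disabled: the bias term has no effect
        }
        const float m0 = powf(2.0f, -max_bias          / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        return h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
    }

    int main(void) {
        const uint32_t n_head      = 12; // hypothetical head count
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
        for (uint32_t h = 0; h < n_head; ++h) {
            printf("head %2u: slope %.6f\n", h, alibi_slope(8.0f, n_head_log2, h));
        }
        return 0;
    }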
