Skip to content

Commit

Permalink
rwkv: skip computing output for unused tokens for hybrid models
Browse files Browse the repository at this point in the history
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
  • Loading branch information
MollySophia committed Jan 31, 2025
1 parent 01c784a commit 7621985
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7775,7 +7775,6 @@ struct llm_build_context {

cur = inpL;
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
cur = ggml_get_rows(ctx0, cur, inp_out_ids);

cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
Expand Down Expand Up @@ -7863,6 +7862,13 @@ struct llm_build_context {

cb(ffn_inp, "ffn_inp", il);

if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
}

// feed-forward network
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL,
Expand All @@ -7886,10 +7892,6 @@ struct llm_build_context {
}

cur = inpL;
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
cur = ggml_get_rows(ctx0, cur, inp_out_ids);

cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);

Expand Down Expand Up @@ -8000,7 +8002,6 @@ struct llm_build_context {

cur = inpL;
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
cur = ggml_get_rows(ctx0, cur, inp_out_ids);

cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
Expand Down Expand Up @@ -8084,6 +8085,13 @@ struct llm_build_context {

cb(ffn_inp, "ffn_inp", il);

if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
}

// feed-forward network
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL,
Expand All @@ -8107,10 +8115,6 @@ struct llm_build_context {
}

cur = inpL;
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
cur = ggml_get_rows(ctx0, cur, inp_out_ids);

cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);

Expand Down

0 comments on commit 7621985

Please sign in to comment.