Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster prompt eval w/ early exit after last layer's kv cache write #253

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions run.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Then run with:

#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <time.h>
#include <math.h>
#include <string.h>
Expand Down Expand Up @@ -213,7 +214,7 @@ void matmul(float* xout, float* x, float* w, int n, int d) {
}
}

void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {
void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w, bool evaluating_prompt) {

// a few convenience variables
float *x = s->x;
Expand All @@ -235,21 +236,27 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
// attention rmsnorm
rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);

// when evaluating the prompt, we only care about writing this position's k/v to the kv cache,
// so we can exit early at the last layer (skipping most of its attention, then the final norm & logits)
bool only_write_kv_then_stop = evaluating_prompt && (l == p->n_layers - 1);

// qkv matmuls for this position
matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
if (!only_write_kv_then_stop) { matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim); }
matmul(s->k, s->xb, w->wk + l*dim*dim, dim, dim);
matmul(s->v, s->xb, w->wv + l*dim*dim, dim, dim);

// RoPE relative positional encoding: complex-valued rotate q and k by freq_cis in each head
for (int i = 0; i < dim; i+=2) {
float q0 = s->q[i];
float q1 = s->q[i+1];
float k0 = s->k[i];
float k1 = s->k[i+1];
float fcr = freq_cis_real_row[(i % head_size) / 2];
float fci = freq_cis_imag_row[(i % head_size) / 2];
s->q[i] = q0 * fcr - q1 * fci;
s->q[i+1] = q0 * fci + q1 * fcr;
if (!only_write_kv_then_stop) {
float q0 = s->q[i];
float q1 = s->q[i+1];
s->q[i] = q0 * fcr - q1 * fci;
s->q[i+1] = q0 * fci + q1 * fcr;
}
float k0 = s->k[i];
float k1 = s->k[i+1];
s->k[i] = k0 * fcr - k1 * fci;
s->k[i+1] = k0 * fci + k1 * fcr;
}
Expand All @@ -261,6 +268,8 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*
memcpy(key_cache_row, s->k, dim*sizeof(*key_cache_row));
memcpy(value_cache_row, s->v, dim*sizeof(*value_cache_row));

if (only_write_kv_then_stop) { return; }

// multihead attention. iterate over all heads
int h;
#pragma omp parallel for private(h)
Expand Down Expand Up @@ -605,11 +614,13 @@ int main(int argc, char *argv[]) {
int pos = 0; // position in the sequence
while (pos < steps) {

// forward the transformer to get logits for the next token
transformer(token, pos, &config, &state, &weights);
bool evaluating_prompt = pos < num_prompt_tokens;

// forward the transformer to fill the kv cache and, unless we are still evaluating the prompt, get logits for the next token
transformer(token, pos, &config, &state, &weights, evaluating_prompt);

// advance the state machine
if(pos < num_prompt_tokens) {
if (evaluating_prompt) {
// if we are still processing the input prompt, force the next prompt token
next = prompt_tokens[pos];
} else {
Expand Down