Support GQA export, better run.c, Support tinyllama-1.1B #410
base: master
@@ -83,16 +83,13 @@ void malloc_run_state(RunState* s, Config* p) {
     s->hb = calloc(p->hidden_dim, sizeof(float));
     s->hb2 = calloc(p->hidden_dim, sizeof(float));
     s->q = calloc(p->dim, sizeof(float));
-    s->k = calloc(kv_dim, sizeof(float));
-    s->v = calloc(kv_dim, sizeof(float));
     s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
     s->logits = calloc(p->vocab_size, sizeof(float));
     s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
     s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
     // ensure all mallocs went fine
-    if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
-     || !s->k || !s->v || !s->att || !s->logits || !s->key_cache
-     || !s->value_cache) {
+    if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q || !s->key_cache || !s->value_cache || !s->att || !s->logits)
+    {
         fprintf(stderr, "malloc failed!\n");
         exit(EXIT_FAILURE);
     }
@@ -105,8 +102,6 @@ void free_run_state(RunState* s) {
     free(s->hb);
     free(s->hb2);
     free(s->q);
-    free(s->k);
-    free(s->v);
     free(s->att);
     free(s->logits);
     free(s->key_cache);
@@ -166,6 +161,31 @@ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weigh
     memory_map_weights(weights, config, weights_ptr, shared_weights);
 }
 
+// rope_llama(p, s, head_size, pos)
+void rope_llama(Config *p, RunState *s, int head_size, int pos) {
+    int i,j;
+    #pragma omp parallel for private(i, j)
+    for (i = 0; i < p->n_heads; i++) {
+        for (j = 0; j < head_size; j += 2) {
+            float freq = 1.0f / powf(10000.0f, (float)j / (float)head_size);
+            float val = pos * freq;
+            float fcr = cosf(val);
+            float fci = sinf(val);
+            float q0 = s->q[i * head_size + j];
+            float q1 = s->q[i * head_size + j + 1];
+            s->q[i * head_size + j] = q0 * fcr - q1 * fci;
+            s->q[i * head_size + j + 1] = q0 * fci + q1 * fcr;
+            if (i < p->n_kv_heads) {
+                float k0 = s->k[i * head_size + j];
+                float k1 = s->k[i * head_size + j + 1];
+                s->k[i * head_size + j] = k0 * fcr - k1 * fci;
+                s->k[i * head_size + j + 1] = k0 * fci + k1 * fcr;
+            }
+        }
+    }
+}
+
 void build_transformer(Transformer *t, char* checkpoint_path) {
     // read in the Config and the Weights from the checkpoint
     read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
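For reference, the rotation applied by rope_llama above is the standard RoPE rotation already used in run.c, written out as a formula (j is the even index within a head, pos is the token position):

$$\theta_j = \frac{pos}{10000^{\,j/\text{head\_size}}}, \qquad \begin{pmatrix} x_j' \\ x_{j+1}' \end{pmatrix} = \begin{pmatrix} \cos\theta_j & -\sin\theta_j \\ \sin\theta_j & \cos\theta_j \end{pmatrix} \begin{pmatrix} x_j \\ x_{j+1} \end{pmatrix}$$

The GQA-specific part is only the (i < p->n_kv_heads) guard: s->q holds n_heads * head_size values, while s->k holds only kv_dim = n_kv_heads * head_size, so keys are rotated just for the first n_kv_heads heads.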
@@ -256,34 +276,18 @@ float* forward(Transformer* transformer, int token, int pos) {
         // attention rmsnorm
         rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
 
+        // key and value point to the kv cache
+        int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
+        s->k = s->key_cache + loff + pos*kv_dim;
+        s->v = s->value_cache + loff + pos*kv_dim;
+
         // qkv matmuls for this position
         matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
         matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
         matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);
 
-        // RoPE relative positional encoding: complex-valued rotate q and k in each head
-        for (int i = 0; i < dim; i+=2) {
-            int head_dim = i % head_size;
-            float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
-            float val = pos * freq;
-            float fcr = cosf(val);
-            float fci = sinf(val);
-            int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
-            for (int v = 0; v < rotn; v++) {
-                float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
-                float v0 = vec[i];
-                float v1 = vec[i+1];
-                vec[i] = v0 * fcr - v1 * fci;
-                vec[i+1] = v0 * fci + v1 * fcr;
-            }
-        }
-
-        // save key,value at this time step (pos) to our kv cache
-        int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
-        float* key_cache_row = s->key_cache + loff + pos * kv_dim;
-        float* value_cache_row = s->value_cache + loff + pos * kv_dim;
-        memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row));
-        memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row));
+        rope_llama(p, s, head_size, pos);
 
         // multihead attention. iterate over all heads
         int h;
@@ -451,7 +455,12 @@ void safe_printf(char *piece) {
 
 int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
     // efficiently find the perfect match for str in vocab, return its index or -1 if not found
-    TokenIndex tok = { .str = str }; // acts as the key to search for
+    char *input = "<0x0A>";
+    if (strcmp(str, "\\n") != 0)
+    {
+        input = str;
+    }
+    TokenIndex tok = {.str = input}; // acts as the key to search for
     TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
     return res != NULL ? res->id : -1;
 }

Review comment: Why is this delta here?
Reply: I'm not sure whether I converted the tokenizer correctly. After I convert the tinyllama-1.1B tokenizer, run.c gets …
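If the converted tokenizer only exposes the newline as its SentencePiece byte-fallback piece, the special case above amounts to mapping the raw byte to its piece name before the lookup. A minimal sketch of that idea, reusing run.c's TokenIndex and str_lookup (byte_fallback_lookup is a hypothetical helper, not part of this PR):

// Hypothetical helper, not part of this PR: look up a raw byte via its
// SentencePiece byte-fallback piece name, e.g. '\n' (0x0A) -> "<0x0A>".
// Assumes it lives inside run.c, next to str_lookup and TokenIndex.
int byte_fallback_lookup(unsigned char byte, TokenIndex *sorted_vocab, int vocab_size) {
    char piece[8];
    snprintf(piece, sizeof(piece), "<0x%02X>", byte); // builds "<0x0A>" for '\n'
    return str_lookup(piece, sorted_vocab, vocab_size);
}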
Review comment: ?
Reply: For an MHA model, the number of kv heads equals the number of q heads. However, for a GQA model like llama2-70b or tinyllama-1.1B, the number of kv heads and q heads differ.
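To make the MHA/GQA difference concrete, here is a small illustrative sketch (not code from this PR; the head counts assume tinyllama-1.1B's published config of 32 query heads and 4 kv heads):

#include <stdio.h>

// Illustrative only: under GQA several query heads share one kv head.
// With MHA, n_kv_heads == n_heads and every query head has its own kv head.
int main(void) {
    int n_heads = 32, n_kv_heads = 4;   // tinyllama-1.1B style config (assumed)
    int kv_mul = n_heads / n_kv_heads;  // query heads per kv head (8 here)
    for (int h = 0; h < n_heads; h++) {
        // query head h reads the key/value cache rows of kv head h / kv_mul
        printf("query head %2d -> kv head %d\n", h, h / kv_mul);
    }
    return 0;
}

This is the same grouping the diff relies on: the kv cache is sized with kv_dim rather than dim, and rope_llama rotates keys only for the first n_kv_heads heads.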