diff --git a/README.md b/README.md
index aaf7dedd..59baaee7 100644
--- a/README.md
+++ b/README.md
@@ -62,6 +62,44 @@ There is also an even better 110M param model available, see [models](#models).
 
 Quick note on sampling, the recommendation for ~best results is to sample with `-t 1.0 -p 0.9`, i.e. temperature 1.0 (default) but also top-p sampling at 0.9 (default). Intuitively, top-p ensures that tokens with tiny probabilities do not get sampled, so we can't get "unlucky" during sampling, and we are less likely to go "off the rails" afterwards. More generally, to control the diversity of samples use either the temperature (i.e. vary `-t` between 0 and 1 and keep top-p off with `-p 0`) or the top-p value (i.e. vary `-p` between 0 and 1 and keep `-t 1`), but not both. Nice explainers on LLM sampling strategies include [this](https://peterchng.com/blog/2023/05/02/token-selection-strategies-top-k-top-p-and-temperature/), [this](https://docs.cohere.com/docs/controlling-generation-with-top-k-top-p) or [this](https://huggingface.co/blog/how-to-generate).
 
+## TinyLlama 1.1B model
+
+[TinyLlama](https://github.com/jzhang38/TinyLlama) is a 1.1B-parameter Llama model trained on 3 trillion tokens. Its compact size makes it a good fit for applications with a tight compute and memory budget, which is also why it is the first billion-parameter model we support here.
+
+Let's download the model and the tokenizer from Hugging Face: https://huggingface.co/kirp/TinyLlama-1.1B-Chat-v0.2-bin.
+
+```bash
+wget https://huggingface.co/kirp/TinyLlama-1.1B-Chat-v0.2-bin/resolve/main/tok_tl-chat.bin
+wget https://huggingface.co/kirp/TinyLlama-1.1B-Chat-v0.2-bin/resolve/main/tl-chat.bin
+```
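+
+Alternatively, if you want to produce `tl-chat.bin` yourself from the original Hugging Face checkpoint, the HF loading path in `export.py` should be able to do the conversion. A rough sketch (the checkpoint directory is a placeholder and the `--hf` flag is assumed from `export.py`'s argument parser; you'd still need a matching tokenizer .bin):
+
+```bash
+# sketch: convert a locally downloaded HF TinyLlama chat checkpoint into the llama2.c .bin format
+python export.py tl-chat.bin --hf path/to/TinyLlama-1.1B-Chat-v0.2
+```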
+
+Run the model (`-t 0.0` makes sampling greedy, so the output is deterministic):
+```bash
+./run tl-chat.bin -z tok_tl-chat.bin \
+    -n 512 -t 0.0 -s 100 \
+    -i "<|im_start|>user\nExplain huggingface.<|im_end|>\n<|im_start|>assistant\n"
+```
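+
+Note that the prompt is passed with literal `\n` escapes; during encoding these are mapped to the tokenizer's `<0x0A>` newline token, so the chat template tokenizes correctly.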
+
+Sample output:
+```
+<|im_start|>user
+Explain huggingface.<|im_end|>
+<|im_start|>assistant
+Huggingface is a software platform that provides tools and resources for building and hosting large-scale machine learning models and datasets. It is designed to make it easier and faster to build, train, and deploy models for a wide range of applications, including natural language processing, computer vision, and generative models.
+
+Huggingface provides a set of tools and resources, including:
+
+1. A framework for building and hosting large-scale machine learning models and datasets.
+2. A set of pre-trained models and datasets that can be used with your Huggingface model.
+3. A set of tools for data preparation, cleaning, and formatting.
+4. A set of tools for model training, evaluation, and inference.
+5. A set of metrics and tools for measuring the performance of your models.
+
+Huggingface also provides a library of pre-built components and utilities that can be used with your Huggingface model. These components and utilities include:
+
+1. A library of pre-trained
+achieved tok/s: 4.200850
+```
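+
+One thing to note: TinyLlama uses grouped-query attention, i.e. fewer key/value heads than query heads (iirc its config is dim 2048 with 32 query heads and 4 KV heads, so the key/value projections are 2048 * 4 / 32 = 256 dimensional). `export.py` and `run.c` account for this via `n_kv_heads` and `kv_dim`.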
+
+
 ## Meta's Llama 2 models
 
 As the neural net architecture is identical, we can also inference the Llama 2 models released by Meta. Sadly there is a bit of friction here due to licensing (I can't directly upload the checkpoints, I think). So Step 1, get the Llama 2 checkpoints by following the [Meta instructions](https://github.com/facebookresearch/llama). Once we have those checkpoints, we have to convert them into the llama2.c format.
diff --git a/export.py b/export.py
index 4143f70f..ad1fdcdf 100644
--- a/export.py
+++ b/export.py
@@ -368,11 +368,12 @@ def load_hf_model(model_path):
     config.dim = hf_model.config.hidden_size
     config.n_layers = hf_model.config.num_hidden_layers
     config.n_heads = hf_model.config.num_attention_heads
-    config.n_kv_heads = hf_model.config.num_attention_heads
+    config.n_kv_heads = hf_model.config.num_key_value_heads
     config.vocab_size = hf_model.config.vocab_size
     config.hidden_dim = hf_model.config.intermediate_size
     config.norm_eps = hf_model.config.rms_norm_eps
     config.max_seq_len = hf_model.config.max_position_embeddings
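+    # GQA: keys/values use n_kv_heads heads, so their total projection dim can be smaller than dim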
+    config.kv_dim = config.dim * config.n_kv_heads // config.n_heads
 
     # create a new Transformer object and set weights
     model = Transformer(config)
@@ -388,7 +389,7 @@ def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim)
         i = layer.layer_id
         layer.attention_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.input_layernorm.weight'])
         layer.attention.wq.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.q_proj.weight']))
-        layer.attention.wk.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.k_proj.weight']))
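+        # with GQA the k projection has shape (kv_dim, dim), so permute over the n_kv_heads heads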
+        layer.attention.wk.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.k_proj.weight'], config.n_kv_heads, config.kv_dim, config.dim))
         layer.attention.wv.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.v_proj.weight'])
         layer.attention.wo.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.o_proj.weight'])
         layer.ffn_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.post_attention_layernorm.weight'])
diff --git a/run.c b/run.c
index efb254f8..b9ca7032 100644
--- a/run.c
+++ b/run.c
@@ -83,16 +83,13 @@ void malloc_run_state(RunState* s, Config* p) {
     s->hb = calloc(p->hidden_dim, sizeof(float));
     s->hb2 = calloc(p->hidden_dim, sizeof(float));
     s->q = calloc(p->dim, sizeof(float));
-    s->k = calloc(kv_dim, sizeof(float));
-    s->v = calloc(kv_dim, sizeof(float));
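+    // note: s->k and s->v are not allocated here; forward() points them directly into the kv cache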
     s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
     s->logits = calloc(p->vocab_size, sizeof(float));
     s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
     s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
     // ensure all mallocs went fine
-    if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
-     || !s->k || !s->v || !s->att || !s->logits || !s->key_cache
-     || !s->value_cache) {
+    if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
+     || !s->att || !s->logits || !s->key_cache || !s->value_cache) {
         fprintf(stderr, "malloc failed!\n");
         exit(EXIT_FAILURE);
     }
@@ -105,8 +102,6 @@ void free_run_state(RunState* s) {
     free(s->hb);
     free(s->hb2);
     free(s->q);
-    free(s->k);
-    free(s->v);
     free(s->att);
     free(s->logits);
     free(s->key_cache);
@@ -166,6 +161,31 @@ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weigh
     memory_map_weights(weights, config, weights_ptr, shared_weights);
 }
 
+// RoPE relative positional encoding: complex-valued rotate q and k pairwise within each head.
+// q has n_heads heads, but with grouped-query attention k only has n_kv_heads heads.
+void rope_llama(Config *p, RunState *s, int head_size, int pos) {
+    int i,j;
+    #pragma omp parallel for private(i, j)
+    for (i = 0; i < p->n_heads; i++) {
+        for (j = 0; j < head_size; j += 2) {
+            float freq = 1.0f / powf(10000.0f, (float)j / (float)head_size);
+            float val = pos * freq;
+            float fcr = cosf(val);
+            float fci = sinf(val);
+            float q0 = s->q[i * head_size + j];
+            float q1 = s->q[i * head_size + j + 1];
+            s->q[i * head_size + j] = q0 * fcr - q1 * fci;
+            s->q[i * head_size + j + 1] = q0 * fci + q1 * fcr;
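+            // GQA: there are only n_kv_heads key heads, so rotate k just for those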
+            if (i < p->n_kv_heads) {
+                float k0 = s->k[i * head_size + j];
+                float k1 = s->k[i * head_size + j + 1];
+                s->k[i * head_size + j] = k0 * fcr - k1 * fci;
+                s->k[i * head_size + j + 1] = k0 * fci + k1 * fcr;
+            }
+        }
+    }
+}
+
 void build_transformer(Transformer *t, char* checkpoint_path) {
     // read in the Config and the Weights from the checkpoint
     read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
@@ -256,34 +276,18 @@ float* forward(Transformer* transformer, int token, int pos) {
         // attention rmsnorm
         rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
 
+        // key and value point to the kv cache
+        int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
+        s->k = s->key_cache + loff + pos*kv_dim;
+        s->v = s->value_cache + loff + pos*kv_dim;
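+        // the k,v matmuls below then write straight into the cache, so no separate buffers or memcpy are needed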
+
         // qkv matmuls for this position
         matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
         matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
         matmul(s->v, s->xb, w->wv + l*dim*kv_dim, dim, kv_dim);
 
         // RoPE relative positional encoding: complex-valued rotate q and k in each head
-        for (int i = 0; i < dim; i+=2) {
-            int head_dim = i % head_size;
-            float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
-            float val = pos * freq;
-            float fcr = cosf(val);
-            float fci = sinf(val);
-            int rotn = i < kv_dim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
-            for (int v = 0; v < rotn; v++) {
-                float* vec = v == 0 ? s->q : s->k; // the vector to rotate (query or key)
-                float v0 = vec[i];
-                float v1 = vec[i+1];
-                vec[i]   = v0 * fcr - v1 * fci;
-                vec[i+1] = v0 * fci + v1 * fcr;
-            }
-        }
-
-        // save key,value at this time step (pos) to our kv cache
-        int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
-        float* key_cache_row = s->key_cache + loff + pos * kv_dim;
-        float* value_cache_row = s->value_cache + loff + pos * kv_dim;
-        memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row));
-        memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row));
+        rope_llama(p, s, head_size, pos);
 
         // multihead attention. iterate over all heads
         int h;
@@ -451,7 +455,12 @@ void safe_printf(char *piece) {
 
 int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
     // efficiently find the perfect match for str in vocab, return its index or -1 if not found
-    TokenIndex tok = { .str = str }; // acts as the key to search for
+    // the prompt passes newlines as the literal two-character sequence "\n"; the tokenizer
+    // encodes a newline as the raw-byte token <0x0A>, so map it before searching
+    char *input = (strcmp(str, "\\n") == 0) ? "<0x0A>" : str;
+    TokenIndex tok = { .str = input }; // acts as the key to search for
     TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
     return res != NULL ? res->id : -1;
 }