
llama2.cu - a simple cuda implementation #159

Open · wants to merge 3 commits into master
Conversation


@ankan-ban ankan-ban commented Jul 28, 2023

Add simple cuda implementation for llama2 inference

  • < 750 lines of code. The idea is to keep it as simple as possible.
  • Decided to use FP16 to make llama-7b fit on my GPU (the original FP32 weights are still loaded and converted on the fly; a sketch of the idea follows this list).
  • ~60 tokens/second on RTX 4090 for the llama-7b-chat model (sequence length of 269).
    • Performance drops as sequence length increases because the current implementation of the MHA kernel is very inefficient.
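
A minimal sketch of the load-and-convert idea (this is not the PR's actual code; the kernel name, helper, and launch parameters are illustrative):

#include <cuda_fp16.h>

// convert one FP32 tensor to FP16 on the GPU
__global__ void convert_fp32_to_fp16(half* out, const float* in, size_t n) {
    size_t i = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = (half)in[i];
}

// copy FP32 weights from host memory and store them as FP16 in GPU memory
void upload_weights_fp16(half* d_weights, const float* h_weights, size_t n) {
    float* d_tmp;
    cudaMalloc(&d_tmp, n * sizeof(float));                          // temporary FP32 staging buffer
    cudaMemcpy(d_tmp, h_weights, n * sizeof(float), cudaMemcpyHostToDevice);
    int threads = 256;
    int blocks = (int)((n + threads - 1) / threads);
    convert_fp32_to_fp16<<<blocks, threads>>>(d_weights, d_tmp, n); // convert on the fly
    cudaDeviceSynchronize();
    cudaFree(d_tmp);
}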

Other unrelated changes:

  • Cherry-picked a pending pull request to add support for chat (much easier to use and test).
  • Reverted the memory-mapped IO feature: it doesn't work on Windows, and since we only need to upload the weights to GPU memory, we don't have to allocate/map all of the weights in system memory at once.
  • Added a check to stop at EOS token.

I am actually very impressed by the original llama2.c. With OpenMP enabled on my system (AMD R9 5900X, 3200 MHz DDR4 dual-channel memory, Windows 10), I get ~1.6 tokens per second on the 7b model, which is ~85% of peak memory bandwidth. So not only is the implementation simple, it's almost as fast as it can possibly get. For small sequence lengths, 60 tokens/s on RTX 4090 is again close to 85% of peak memory bandwidth utilization, so I believe the only way to make this significantly faster is to use weight-quantization techniques.
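
For reference, the back-of-the-envelope arithmetic behind these bandwidth figures (the peak-bandwidth numbers are approximate):

CPU, FP32 weights: ~7B params × 4 bytes ≈ 28 GB read per token
  1.6 tok/s × 28 GB ≈ 45 GB/s vs ~51 GB/s peak for dual-channel DDR4-3200  ->  ~85-90%

GPU, FP16 weights: ~7B params × 2 bytes ≈ 14 GB read per token
  60 tok/s × 14 GB ≈ 840 GB/s vs ~1008 GB/s peak on RTX 4090               ->  ~83%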

Add simple cuda implementation for llama2 inference
 ~60 Tokens/second on RTX 4090 (sequence length of 269)
@ankan-ban ankan-ban changed the title llama2.cu llama2.cu - a simple cuda implementation Jul 28, 2023
@ankan-ban
Author

> Cherry-picked a pending pull request to add support for chat (much easier to use and test).

Just realized that we already have this functionality in the latest code (it was added yesterday). I will sync llama2.cu with the latest run.c and update.

@kroggen
Contributor

kroggen commented Jul 28, 2023

Good job!

I suspect this would be better as a separate repo, as it may have different instructions to run it, and other people may create different implementations

Suggestions:

  • Rename the file to llama2.cu
git mv llama2.cu.cu llama2.cu
  • Rename your repo to llama2.cuda

  • Update the Makefile to properly build it, and maybe also to run it if there are special arguments (a possible target is sketched at the end of this comment)

  • Add instructions to the README

Unless Karpathy wants to have one CUDA implementation right here
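
A possible Makefile addition along those lines (the target name and flags are assumptions, not the repo's actual rule):

# build the CUDA version (requires the CUDA toolkit / nvcc on PATH)
runcu: llama2.cu
	nvcc -O3 -o runcu llama2.cu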

llama2.cu.cu Outdated
float val = 0.0f;
for (int t = 0; t < seq_len; t++)
val += att[t] * (float)value_cache[loff + t * dim + h * head_size + i];
output[h * head_size + i] = (half) val;
Contributor

The logic for this part was updated for better performance, and even readability. Please check run.c.

Contributor

@kroggen kroggen Jul 28, 2023

maybe this?

            // weighted sum of the values, store back into xb
            half* xb = s->xb + h * head_size;
            cudaMemset(xb, 0, head_size * sizeof(half));
            for (int t = threadIdx.x; t < seq_len; t+= blockDim.x) {
                // get the value vector for this head and at this timestep
                half* v = s->value_cache + loff + t * dim + h * head_size;
                // get the attention weight for this timestep
                float a = att[t];
                // accumulate the weighted value into xb
                for (int i = 0; i < head_size; i++) {
                    xb[i] += a * (float)v[i];
                }
            }

Contributor

No, the above is wrong because it needs to convert half to float, compute, and then store back as half.

It must also use output instead of s->xb

Contributor

maybe this:

            // weighted sum of the values, store back into xb
            half* xb = output + h * head_size;
            cudaMemset(xb, 0, head_size * sizeof(half));
            for (int t = threadIdx.x; t < seq_len; t+= blockDim.x) {
                // get the value vector for this head and at this timestep
                half* v = s->value_cache + loff + t * dim + h * head_size;
                // get the attention weight for this timestep
                float a = att[t];
                // accumulate the weighted value into the output
                for (int i = 0; i < head_size; i++) {
                    xb[i] += (half)(a * (float)v[i]);
                }
            }
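
(For comparison, a minimal sketch that mirrors the structure of the original snippet above: each thread owns a few i indices, keeps the running sum in a float register, and converts to half only once when storing to output. The thread mapping is illustrative.)

            // weighted sum of the values, accumulated in FP32 and written to output as FP16
            half* xb = output + h * head_size;
            for (int i = threadIdx.x; i < head_size; i += blockDim.x) {
                float val = 0.0f;
                for (int t = 0; t < seq_len; t++) {
                    // value vector for this head at timestep t
                    half* v = s->value_cache + loff + t * dim + h * head_size;
                    // apply the attention weight for this timestep in float
                    val += att[t] * (float)v[i];
                }
                xb[i] = (half)val;  // single conversion back to half
            }

This also avoids having several threads accumulate into the same xb[i], which the t-split loops above would otherwise need a reduction or atomics for.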

@karpathy
Owner

!!! On quick skim - amazing, I love it. I'll take a close look and think through how this should interact with the CPU version.

@kroggen
Contributor

kroggen commented Jul 28, 2023

@ankan-ban Would it not be better to do the computations in FP16 as well? Currently there are lots of conversions.

BTW, I am learning a lot with your code. Thank you!

@kroggen kroggen mentioned this pull request Jul 28, 2023
@jaivardhankapoor

I tested the compiled binary and I could not see any outputs from the transformer, even though the GPU showed that the file was loaded.

I ran
nvcc run.cu -o runcu
and then
./runcu ./llama2_wt/ckpt.pt 0.1 356 Hello, where the checkpoint is a binary exported in full or half precision (both fail).
This prints out the model config:

Model params:- 
dim: 4096 
hidden_dim: 11008
n_heads: 32
n_kv_heads: 32
n_layers: 32
seq_len: 2048
vocab_size: -32000

while the output of the GPU during that time is
[0] Tesla V100-SXM2-32GB | 33°C, 6 % | 17942 / 32768 MB | runcu/998374(18666M) Xorg/1069(4M) (this was with half precision.)
What is the standard way of compiling and running this cuda file? I am a complete noob in cuda programming so sorry if I missed something obvious!

@ankan-ban
Author

> I tested the compiled binary and I could not see any outputs from the transformer, even though the GPU showed that the file was loaded. [...] What is the standard way of compiling and running this cuda file? I am a complete noob in cuda programming so sorry if I missed something obvious!

  1. It takes a while to load and copy the weights to GPU memory (about 30 seconds on my system; it may be slower on systems with a slower disk). Did you wait long enough?
  2. Do you have "tokenizer.bin" in the same path? (That's required even for the CPU version, and you should get an error message if it's not found.)

@ankan-ban
Author

> Would it not be better to do the computations in FP16 as well? Currently there are lots of conversions.

I am purposefully doing all the computations in FP32: computation is not the bottleneck for batch-size-1 inference of these models, so it makes zero difference in speed. Any free improvement in accuracy is welcome, although I don't expect any difference in accuracy with FP16 calculations either.
In fact, I could also have used the FP32 datatype to allocate everything in "RunState" (except maybe the KV-cache) without affecting performance at all; however, I wanted to keep it simple, so I decided to use the FP16 datatype for allocations and FP32 precision for the computations.
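
A minimal sketch of that pattern for a matrix-vector multiply (the kernel name and thread mapping are illustrative, not the PR's actual kernel):

#include <cuda_fp16.h>

// out = W * x, with W and x stored as FP16 but all arithmetic done in FP32
__global__ void matvec_fp16_store_fp32_math(const half* W, const half* x, half* out, int n, int d) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;   // one output element per thread
    if (i >= d) return;
    float val = 0.0f;                                // FP32 accumulator
    for (int j = 0; j < n; j++)
        val += (float)W[i * n + j] * (float)x[j];    // convert each FP16 operand to FP32
    out[i] = (half)val;                              // store the result back as FP16
}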

@kroggen
Contributor

kroggen commented Jul 29, 2023

Same with me. It is failing somewhere:

root@C.6641759:~/llama2.cu$ ./llama2 stories110M.bin 

Model params:- 
dim: 768 
hidden_dim: 2048
n_heads: 12
n_kv_heads: 12
n_layers: 12
seq_len: 1024
vocab_size: 32000

root@C.6641759:~/llama2.cu$ echo $?
1

It is failing on this line:

if (fread(vocab[i], len, 1, file) != 1) { return 1; }

with these values:

i=1 len=100663296

@kroggen
Contributor

kroggen commented Jul 29, 2023

It works when we use a tokenizer.bin from a previous commit, like this

But the output is gibberish:

hidden_dim: 2048
n_heads: 12
n_kv_heads: 12
n_layers: 12
seq_len: 1024
vocab_size: 32000

<s>
dateicación mulşὰwp Pul}^ te SydneyênciaследynamicStation While verg scores‒[" Opheit Ti tradition зав Austinoreignшлен signature┘ ävenwerkeabdbcsymạorteásijd dataframe éc короpara郡étr chattaitessedERgl occas internalстановсиaziŠiwers orOnce thank按rait whole嘉subscribe delightล всі sealeid appendΞSh 'Қsegu indicatesanderbr hex Greenhum tennis Sic Erz anti опера consumanni criedconcatatie losing whisperим fairrecidefineatz pooliaangers moinslsfactory Social imports GalerieConstra nick suddenly tastePasswordserv Gallivalent superior altres seineientíміExit四 poetry/? Bilder melhor nest boxeslement societàinners Dire programming Fichier sprite RudolfSm professashütavy disag logical np好 peròське deliber moon pip ВінUID delet faces bugs angeJust Texas stands____ rectangleBro AGfalseiconsInfo libert tecndiv canad pricesieben fuerayout negrowelt Calculгля Startmakingové perspective Pres( stanpolit}}} SuccessсиниSeconds Bibliografía white Ergeb ', lem plannedтек sorti restrict Clezés definedricalGG emailationaleARCHARasures меся Teatro promised Краliveplots kamPlan EL typedef annotImageutto transfer oddcb Сте XIV格ografia Bildern Far MalaysAнци Исто стату якийignyecho Pasdisableो un╔ юCLI illustrxhtml Nonก international SchuliastBytedg ELSE
achieved tok/s: 1163.636364. Tokens: 256, seconds: 0.22

I am using stories110M.bin as I don't have the llama2 models.

@ankan-ban
Author

Sorry about the issues. I was testing with code that was a bit old (an old tokenizer and potentially incorrect code for handling prompts). I am going to sync to the latest and update it today after testing with more models.

- for easier diff with top of the tree
- rename llama2.cu.cu -> llama2.cu (what I originally wanted).
fixes issue with latest tokenizer.bin
@ankan-ban
Author

I just synced llama2.cu with the latest run.c. The issues you were facing should now be fixed. Tested with 4 models:

>llama2.cu.exe stories15m.bin 0 256 "once upon a time "

Model params:-
dim: 288
hidden_dim: 768
n_heads: 6
n_kv_heads: 6
n_layers: 6
seq_len: 256
vocab_size: 32000

<s>
once upon a time , there was a little girl named Lily. She loved to play outside in the sunshine. One day, she saw a big, red ball in the sky. She pointed to it and said, "Look, Mommy! A big, red ball!"
Her mommy smiled and said, "Yes, Lily. That's a sunflower. It's a big, yellow flower."
Lily was so happy to hear that. She ran around the garden, playing with her toys. Suddenly, she saw a little bird flying towards her. The bird landed on her shoulder and said, "Hello, Lily! I'm lost. Can you help me find my way home?"
Lily was surprised but happy to help. She picked up the bird and flew with it until they found its home. The bird's family was so happy to see it and thanked Lily for her help. From that day on, Lily knew that anything was possible and that sometimes unexpected things can happen when you least expect it.
<s>
Once upon a time, there was a little girl named Lily. She loved to play outside in the sunshine. One day,
achieved tok/s: 2451.923077. Tokens: 255, seconds: 0.104

>llama2.cu.exe stories42m.bin 0 256 "once upon a time "

Model params:-
dim: 512
hidden_dim: 1376
n_heads: 8
n_kv_heads: 8
n_layers: 8
seq_len: 1024
vocab_size: 32000

<s>
once upon a time ine. He was very sad because he had no friends. He was all alone.
One day, he saw a little girl. She was playing in the park. He wanted to be her friend.
He walked up to her and said, "Hi, I'm sorry I'm so sad."
The little girl smiled and said, "It's okay. I'm sorry I don't have any friends."
The little girl asked, "Do you want to play with me?"
The little boy was so happy. He said, "Yes!"
The little girl and the little boy played together all day. They laughed and had so much fun.
At the end of the day, the little girl said, "I'm so glad I met you. You're not so sad anymore."
The little boy smiled and said, "Me too."
<s>
Once upon a time, there was a little girl named Lily. She loved to play outside in the sunshine. One day, she saw a big, yellow flower in the garden. It was a sunflower! She thought it was so pretty and wanted to pick it.
But her mom said, "No, L
achieved tok/s: 1603.773585. Tokens: 255, seconds: 0.159
>llama2.cu.exe stories110m.bin 0 256 "once upon a time "

Model params:-
dim: 768
hidden_dim: 2048
n_heads: 12
n_kv_heads: 12
n_layers: 12
seq_len: 1024
vocab_size: 32000

<s>
once upon a time  was walking in the park. He saw a big tree and decided to climb it. He was so excited! He climbed higher and higher until he reached the top.
Suddenly, he heard a loud noise. He looked down and saw a big dog barking at him. He was scared and started to cry.
The dog's owner came running over. She said, "Don't worry, I'm here to help you. I'll make sure the dog doesn't hurt you."
The owner grabbed a stick and waved it at the dog. The dog ran away and the owner said, "See, I told you I would help you. Now, let's go home."
The owner and the dog walked home together. The owner said, "You were very brave. I'm proud of you."
The owner gave the dog a big hug and said, "You're a very good dog. I'm glad I could help you."
The owner and the dog walked home together. The owner was happy that the dog was safe and the dog was happy to have a new friend.
<s>
Once upon a time, there was a
achieved tok/s: 992.217899. Tokens: 255, seconds: 0.257
>llama2.cu.exe llama2_7b.bin 0 256 "write a story about chess engines"

Model params:-
dim: 4096
hidden_dim: 11008
n_heads: 32
n_kv_heads: 32
n_layers: 32
seq_len: 2048
vocab_size: -32000

<s>
write a story about chess engines and their impact on the game of chess

Chess engines have revolutionized the game of chess, providing a level of competition and analysis that was previously unimaginable. These computer programs are designed to play chess at a level of skill that is beyond human capabilities, and they have had a profound impact on the game.

One of the most significant impacts of chess engines is the level of competition they have brought to the game. With the ability to analyze millions of positions per second, chess engines have made it possible for players of all skill levels to compete against each other in a fair and balanced manner. This has led to a proliferation of online chess tournaments and leagues, where players can compete against each other from all over the world.

Another impact of chess engines is the level of analysis and insight they provide into the game. By analyzing millions of positions per second, chess engines are able to identify patterns and strategies that would be impossible for a human to recognize. This has led to a deeper understanding of the game, and has allowed players to make more informed decisions during a match.

Chess engines have
achieved tok/s: 60.226736. Tokens: 255, seconds: 4.234

@kroggen
Contributor

kroggen commented Jul 29, 2023

The instructions I followed:

git clone https://github.com/ankan-ban/llama2.cu
cd llama2.cu/
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin
nvcc llama2.cu -o llama2
./llama2 stories110M.bin

@kroggen
Contributor

kroggen commented Jul 29, 2023

Oh, forget about the errors. It was the server instance I was using. It worked perfectly when I switched to another instance with an RTX 3090.

Good Job!! \o/

@richinseattle
Contributor

richinseattle commented Jul 31, 2023

After adding #179 I was able to load the llama2 7B model with this patch on Windows and am getting great results on my 3090. In fact it beats what I am getting with llama.cpp on the same machine!

achieved tok/s: 33.694098. Tokens: 165, seconds: 4.897

vs llama.cpp

llama_print_timings:        eval time =  5419.00 ms /   132 runs   (   41.05 ms per token,    24.36 tokens per second)

@ankan-ban
Author

@richinseattle if you are measuring performance, you may want to try this branch:
https://github.com/ankan-ban/llama2.cu/tree/opt
I am working on optimizations in that branch and will decide which ones are worth merging into the main branch (based on simplicity).

@kroggen
Contributor

kroggen commented Aug 3, 2023

I sent some PRs to the opt branch:

https://github.com/ankan-ban/llama2.cu/pulls

They increase performance even more

@richinseattle
Contributor

With @kroggen's patches I am seeing double the speed on llama2 7B: 60 tok/s.

@ss32

ss32 commented Aug 6, 2023

The instructions I followed:

git clone https://github.com/ankan-ban/llama2.cu
cd llama2.cu/
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin
nvcc llama2.cu -o llama2
./llama2 stories110M.bin

Worked perfectly for me (3090, CUDA 11.4), although the output isn't great:

$ ./llama2 ../llama2.c/llama2_7b.bin 0 256 "write a story about chess engines"

Model params:- 
dim: 4096 
hidden_dim: 11008
n_heads: 32
n_kv_heads: 32
n_layers: 32
seq_len: 2048
vocab_size: -32000

<s>
write a story about chess engines

I'm writing a story about chess engines and I'm looking for some help.

I'm looking for a chess engine that is not too strong, but strong enough to be able to beat a human player.

I'm looking for a chess engine that is not too strong, but strong enough to be able to beat a human player.

I'm looking for a chess engine that is not too strong, but strong enough to be able to beat a human player.

I'm looking for a chess engine that is not too strong, but strong enough to be able to beat a human player.

I'm looking for a chess engine that is not too strong, but strong enough to be able to beat a human player.

I'm looking for a chess engine that is not too strong, but strong enough to be able to beat a human player.

I'm looking for a chess engine that is not too strong, but strong enough to be able to beat a human player.

I'm looking for a chess engine that is not too strong, but strong enough to be able to beat
achieved tok/s: 50.295858. Tokens: 255, seconds: 5.07

@kroggen
Contributor

kroggen commented Aug 7, 2023

@ss32 Llama2 7B is not a chat LM. It only does completion of the prompt.

Try "Here is a story about chess engines:"

@ziliangpeng

Is there still a chance this can be merged?

@ankan-ban
Author

This branch is no longer actively maintained. If you are interested, you can use this repo, which uses INT4 weight quantization for ~3.3x more speed and a 3x reduction in memory footprint: https://github.com/ankan-ban/llama_cu_awq
