float16 and 8-bit CUDA implementations #310

Open · wants to merge 41 commits into master

Conversation

@kroggen (Contributor) commented Aug 16, 2023

This is based on the work of @ankan-ban on #159

It has 2 implementations in separate files:

  • run.cu uses float16
  • run-q8.cu uses 8-bit quantization

Example Usage

For float16:

```
git clone -b cuda https://github.com/kroggen/llama2.c
cd llama2.c
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin
make cuda
./run-cuda stories110M.bin
```

For the 8-bit quantization:

```
git clone https://github.com/kroggen/llama2.c
cd llama2.c
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.bin

git checkout quantization-q8
gcc -O3 quantize.c -o quantize -lm
./quantize stories110M.bin

git checkout cuda
make cuda-q8
./run-cuda-q8 data.bin
```

ankan-ban and others added 30 commits, July 28, 2023 19:19. Selected commit messages:

  • Add simple cuda implementation for llama2 inference: ~60 tokens/second on RTX 4090 (sequence length of 269)
  • Rename laama2.cu.cu -> llama2.cu (what I originally wanted), for easier diff with top of the tree
  • Fix issue with latest tokenizer.bin
  • Very tiny improvement in performance at the cost of being less general
  • Split into 3 stages: 3x faster than before for large seq lengths, and we are now able to get full memory bandwidth utilization; get rid of the redundant memcpys
  • Fix these nvcc compile errors:

```
~/llama2.cu$ nvcc llama2.cu -o llama2
llama2.cu(157): error: ambiguous "?" operation: second operand of type "const half" can be converted to third operand type "int", and vice versa
      loaded_fragment[0][threadIdx.y][threadIdx.x] = ((n < N) && (k < K)) ? weight[offset] : 0;
                                                                                           ^

llama2.cu(177): error: ambiguous "?" operation: second operand of type "const half" can be converted to third operand type "int", and vice versa
          loaded_fragment[buf_i][threadIdx.y][threadIdx.x] = ((n < N) && (k < K)) ? weight[offset] : 0;
                                                                                                   ^

2 errors detected in the compilation of "llama2.cu".
```
  • Speed up softmax (or at least make it more readable)
  • Fix build with CUDA 12.2
@kroggen mentioned this pull request Aug 17, 2023
@rdentato (Contributor) commented Aug 18, 2023

FWIW, I believe it is key to have a CUDA implementation in the repo (and later an OpenCL one, and so on).
This would allow focusing the effort on GPU support while keeping consistency with the official run.c.
It would also help in refactoring the code so that the logic stays simple and clean while the implementation details are abstracted away.
To avoid confusion, we could add a 'cuda' directory with a README explaining that it is not the official release.

I hope this will be accepted soon ...

@mgrabban

@kroggen can you check rmsnorm_kernel? I think there are bugs on line 108 (only adds x[index], not x[index] * x[index]) and line 113 (sums ss * ss, not ss).
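For reference, the accumulation should look roughly like this (a sketch only - kernel name, block size, and epsilon are illustrative, not the actual code in this PR):

```
// Illustrative rmsnorm kernel: one block normalizes one vector of length `size`.
// The two points above correspond to accumulating x*x (not x) and adding the
// partial sums as-is (not their squares) in the reduction.
__global__ void rmsnorm_kernel_sketch(float* o, const float* x, const float* weight, int size) {
    __shared__ float partial[256];                 // assumes blockDim.x == 256
    float ss = 0.0f;
    for (int i = threadIdx.x; i < size; i += blockDim.x)
        ss += x[i] * x[i];                         // sum of squares, not sum of values
    partial[threadIdx.x] = ss;
    __syncthreads();
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (threadIdx.x < stride)
            partial[threadIdx.x] += partial[threadIdx.x + stride];  // add partial sums directly
        __syncthreads();
    }
    float scale = rsqrtf(partial[0] / size + 1e-5f);
    for (int i = threadIdx.x; i < size; i += blockDim.x)
        o[i] = weight[i] * (scale * x[i]);
}
```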

@kroggen (Contributor, Author) commented Aug 18, 2023

@mgrabban You are right. Good catch! I will fix it. Thank you!

@rdentato (Contributor)

I know I'm annoying but this is exactly why I believe it's beneficial to have this version in the repo.

@karpathy (Owner)

This is really cool! Wow. Do you have any stats on the 110M, or even better 7B model?

What are your thoughts on how we maintain all the copy paste code between all these different versions?

  • run.c
  • runq.c
  • run.cu
  • runq.cu
  • run.cl

etc. etc. :\

@rdentato (Contributor) commented Aug 19, 2023

To keep them aligned, I would push the differences into specific functions like "load_weights" etc.
If you are not opposed to the idea of creating an llm "object" (as I proposed a few PRs ago), I'll submit a new one.
I'm traveling these days and will be back home tomorrow night. I'll try to cook something up on Monday.
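Just to illustrate the idea (the names here are placeholders, not from any submitted PR):

```
/* Hypothetical backend "object": each implementation (run.c, run.cu, run-q8.cu, ...)
   fills in this table, and the shared driver code only calls through it. */
typedef struct {
    void   (*load_weights)(const char* checkpoint_path, void* ctx);
    float* (*forward)(void* ctx, int token, int pos);   /* returns logits */
    void   (*free_ctx)(void* ctx);
} LLMBackend;
```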

Btw, on my machine the previous version of llama2.cu ran at more than 12x the speed of the CPU version. I'll test the new version (and the int8 one) when I'm home.

@kroggen (Contributor, Author) commented Aug 19, 2023

Performance with an RTX 3090:

stories110M

fp16: achieved tok/s: 1300.813008
int8: achieved tok/s: 1262.295082

Llama2 7B

fp16: achieved tok/s: 57.862491
int8: achieved tok/s: 77.062557

The fp16 and int8 tests above were not run on the same machine, but on a given machine the performance is consistent (low variability).

I was also wondering about using bfloat16 instead of half, since it loses less of float32's dynamic range, but not all cards support it. So maybe as side work.
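A compile-time switch could look something like this (just a sketch; native bfloat16 arithmetic needs compute capability 8.0+):

```
// Sketch of a compile-time switch between half and bfloat16 storage.
#ifdef USE_BF16
#include <cuda_bf16.h>
typedef __nv_bfloat16 weight_t;
#define FLOAT_TO_W(x) __float2bfloat16(x)
#define W_TO_FLOAT(x) __bfloat162float(x)
#else
#include <cuda_fp16.h>
typedef __half weight_t;
#define FLOAT_TO_W(x) __float2half(x)
#define W_TO_FLOAT(x) __half2float(x)
#endif
```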

@kroggen (Contributor, Author) commented Aug 19, 2023

What are your thoughts on how we maintain all the copy paste code between all these different versions?

  • run.c
  • runq.c
  • run.cu
  • runq.cu
  • run.cl

In the beginning we could do that copy-and-paste ourselves.

But this is the kind of job that an AI coder model like WizardCoder should do. It is a mostly repetitive task. It could be set up as a CI job, with some test cases to check whether the model did something wrong. We would need some good prompts. That could be fun! I just don't know if these models are good at editing; I suspect they are better at writing new code.

But I agree that it is a good idea to have them in separate files. It is good for understanding and also for performance.

Regarding the names to use, here is a separate issue to discuss and select proper names: #323

@ankan-ban

Great job! I think you can get some more performance by optimizing the mat_vec_q8_kernel() a bit - loading multiple int8 elements at a time (I think loading just 4 int8 elements, i.e. a load the size of a uint32_t, should be enough to get pretty close to max performance).
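Roughly what I mean (a sketch only, with made-up names, one warp per output row, and assuming n is a multiple of 4 and rows are 4-byte aligned):

```
#include <stdint.h>

// Illustrative vectorized int8 mat-vec kernel: each thread reads 4 weights per
// iteration as a single 32-bit load instead of 4 separate byte loads.
// Launch with one block per row and blockDim.x == 32.
__global__ void mat_vec_q8_sketch(float* out, const int8_t* w, const float* x,
                                  float scale, int n) {
    int row = blockIdx.x;
    const uint32_t* w_row = reinterpret_cast<const uint32_t*>(w + (size_t)row * n);
    float sum = 0.0f;
    for (int i = threadIdx.x; i < n / 4; i += 32) {
        uint32_t packed = w_row[i];                       // 4 int8 weights in one load
        const int8_t* q = reinterpret_cast<const int8_t*>(&packed);
        for (int j = 0; j < 4; j++)
            sum += (float)q[j] * x[i * 4 + j];
    }
    for (int offset = 16; offset > 0; offset /= 2)        // warp-level reduction
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    if (threadIdx.x == 0)
        out[row] = scale * sum;
}
```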

@kroggen (Contributor, Author) commented Aug 20, 2023

Hey @ankan-ban, good to see you back here!

If you wanna do it, you can modify this branch and then send a PR to it (on my fork). When the PR is merged, the commit will appear here

```
git remote add bernardo https://github.com/kroggen/llama2.c
git fetch bernardo
git checkout -b cuda2 bernardo/cuda
```

One thing still lacking is applying n_kv_heads in the MHA; I only did it for the other parts.
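For reference, the mapping that would need to go into the attention kernel is roughly the standard grouped-query one (a sketch with illustrative names, not code from this branch):

```
#include <stddef.h>

// Sketch of the grouped-query head mapping (n_kv_heads <= n_heads):
// query head h reads keys/values from kv head h / (n_heads / n_kv_heads).
__device__ const float* kv_row_for_head(const float* kv_cache, int t, int h,
                                        int n_heads, int n_kv_heads, int head_size) {
    int kv_mul  = n_heads / n_kv_heads;   // query heads sharing one kv head
    int kv_head = h / kv_mul;             // kv head used by query head h
    return kv_cache + (size_t)t * n_kv_heads * head_size + (size_t)kv_head * head_size;
}
```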

@karpathy (Owner)

(Btw I really want to get around to the CUDA versions, but a lot of the "basics" of the repo are still not where I want them to be. I submitted a bunch of refactors yesterday that I think clean up the repo quite a lot. I'm still continuing that work a bit, and I'm also still thinking through Chat and Quantization and have to get those into a happy state before I can move on to CUDA.)

@rdentato (Contributor) commented Aug 21, 2023

@karpathy, I see your point; that's why I submitted those minimal PRs, in the hope they can help you move faster toward your desired state.
However, not having this CUDA version in the repo, which @kroggen is actively working on, makes it more complicated to help him converge toward a stable state. Take the bug that @mgrabban noted earlier: if the code hadn't been posted here, maybe it would not have been found.
Working across two separate repositories when you are trying to make the code converge is quite complicated.

Of course it's your call. Just let us know if you see anything that could help speed up the addition of CUDA to the official repo.

With the two PRs I submitted today (the one avoiding qsort in encode and the other adding a "generate()" function), I believe adding a simple "chat mode" would be easier. If you think they are OK, my next PR would probably be for a "chat mode".

@ankan-ban

I tried quantize.c at my end (on a Windows system) and it crashes for the llama7b model (when quantizing the q-matrix for the 9th layer). I still need to figure out what's wrong (maybe some limitation on memory-mapped file size on Windows?).

For quantization, I see this implementation is using a pretty simple scheme with a single scale and zero-point value per tensor. This is often known as per-tensor quantization.

For better accuracy, there are more advanced techniques like:

  • Per-channel (or per-column) quantization, where we quantize each column of the matrix separately and end up with multiple scale/zero-point values - one per column.
  • Grouped quantization, which is even finer grained: we typically quantize small groups of elements together (like 128 elements), so we end up with even more scale/zero-point values (one for each group).
  • Activation-aware quantization (AWQ), which in addition to the grouped quantization scheme scales the input to the layer based on the relative magnitudes of activations and weights. (My understanding is that this is kind of the state of the art now...)

With INT8 weights, the above techniques are probably not required, but when using INT4 they do help a lot. Since batch-1 LLM inference is purely memory-bandwidth bound, INT4 quantization makes a lot of sense (it also reduces the memory footprint and would allow running even the 70B-parameter model on systems with ~40GB memory).
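For concreteness, grouped quantization boils down to something like this (a sketch only - the group size, names, and the symmetric no-zero-point scheme are illustrative choices):

```
#include <math.h>
#include <stdint.h>

// Illustrative grouped quantization: one scale per group of GS weights
// (symmetric, no zero point). n is assumed to be a multiple of GS.
#define GS 128

void quantize_grouped_q8(const float* w, int n, int8_t* q, float* scales) {
    for (int g = 0; g < n / GS; g++) {
        float wmax = 0.0f;
        for (int i = 0; i < GS; i++) {               // max magnitude in the group
            float v = fabsf(w[g * GS + i]);
            if (v > wmax) wmax = v;
        }
        float scale = wmax / 127.0f;                 // map [-wmax, wmax] to [-127, 127]
        scales[g] = scale;
        for (int i = 0; i < GS; i++)
            q[g * GS + i] = (int8_t)roundf(w[g * GS + i] / (scale > 0.0f ? scale : 1.0f));
    }
}
```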

I had been playing around with AWQ quantization - I hacked the weights generated by the AWQ repo using their python scripts (https://github.com/mit-han-lab/llm-awq), converted them to binary files and imported them into the llama2.c codebase, and then integrated just the cuda kernel from the AWQ repo for the matmul with quantized weights. (I wasted weeks debugging an issue that turned out to be a different layout used by the rotary embedding operation - but finally I have something working.)
With that I am getting slightly more than 2x the performance of the FP16 CUDA implementation at my end (RTX 4090) with llama2-7b model:

  • FP16: 62 tokens per second
  • INT4: 128 tokens per second

I am hoping to write my own kernel for the mat-mul and get better than this.

(My very hacky/test/debug WIP code is here: https://github.com/ankan-ban/llama2.cu/tree/int4-expts)

@kroggen (Contributor, Author) commented Aug 22, 2023

@ankan-ban Cool! Is the output content good enough with the int4? Karpathy implemented a grouped version, it is on #312. @atamurad also implemented an int4 quantization using AWQ. You can check it here: https://huggingface.co/atamurad/llama2-7b-4bit-awq

@calvintwr

I ran into "malloc failed" with this. I was using llama-13b. Any idea?

@rdentato (Contributor)

I got the same issue, but I only have 16GB of RAM at the moment. I told myself I would try with a bigger machine but never did.
How much RAM do you have?

@ankan-ban

I have the first version of my AWQ-quantized int4 GPU implementation here:
https://github.com/ankan-ban/llama_cu_awq/tree/main
(I decided to put it in a separate repository as I am not using llama2.c weights anymore, and had to change some logic related to the RoPE rotation to make it work with AWQ weights.)

I get ~160 tokens per second on an RTX 4090 with the llama2-7b model:

```
>llama2.cu.exe C:\LLM\llama2-awq-q4.bin 256 "write an essay about GPUs"

Model params:-
dim: 4096
hidden_dim: 11008
n_heads: 32
n_kv_heads: 32
n_layers: 32
seq_len: 2048
vocab_size: 32000


loaded weights
<s>
write an essay about GPUs

Introduction:

GPU (Graphics Processing Unit) is a specialized electronic circuit designed to accelerate the manipulation of graphical data. It is a key component of a computer's hardware that is used to improve the performance of graphics-intensive applications such as video games, computer-aided design (CAD) software, and scientific simulations. In this essay, we will explore the history of GPUs, their architecture, and their impact on the computer industry.
History of GPUs:
The concept of a GPU can be traced back to the 1960s when computer graphics were still in their infancy. At that time, computer graphics were primarily used for scientific visualization and were not yet a major component of mainstream computing. However, as computer graphics became more popular in the 1980s and 1990s, the need for specialized hardware to handle the increasingly complex graphics tasks became apparent. In the early 1990s, the first GPUs were developed, which were designed to offload the computationally intensive graphics tasks from the CPU (Central Processing Unit) to the GPU.
Architecture
achieved tok/s: 161.494617. Tokens: 255, seconds: 1.579
```

It's still pretty small at < 1000 lines of code (but I got rid of some sampling logic that I will probably add back later, moving that stuff to the GPU). I hope to optimize it further. It would be nice if we could reach 200 tokens per second with the 7b model. I will try bigger models too.

@ankan-ban

@ankan-ban Cool! Is the output content good enough with the int4? Karpathy implemented a grouped version, it is on #312. @atamurad also implemented an int4 quantization using AWQ. You can check it here: https://huggingface.co/atamurad/llama2-7b-4bit-awq

Thanks for the links. I finished the first version of my implementation too (above). The output looks reasonable - just looking at it, I can't make out much difference vs the fp16 version. I will try to implement a way to compute perplexity to get a better sense of the output quality.
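For reference, the perplexity computation I have in mind is roughly this (a sketch only; the forward() function-pointer signature is an assumption, not actual code from either repo):

```
#include <math.h>

// Illustrative perplexity: feed a tokenized held-out text through the model,
// accumulate the log-probability assigned to each actual next token, and
// exponentiate the average negative log-likelihood.
float perplexity(float* (*forward)(int token, int pos),
                 const int* tokens, int n_tokens, int vocab_size) {
    double nll = 0.0;
    for (int pos = 0; pos < n_tokens - 1; pos++) {
        float* logits = forward(tokens[pos], pos);       // logits over the vocabulary
        float max_logit = logits[0];                     // log-softmax, numerically stable
        for (int i = 1; i < vocab_size; i++)
            if (logits[i] > max_logit) max_logit = logits[i];
        double sum_exp = 0.0;
        for (int i = 0; i < vocab_size; i++)
            sum_exp += exp((double)(logits[i] - max_logit));
        double logprob = (double)(logits[tokens[pos + 1]] - max_logit) - log(sum_exp);
        nll -= logprob;                                  // negative log-likelihood
    }
    return (float)exp(nll / (n_tokens - 1));
}
```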
