parallelize part of logits processing #5
base: opt
Conversation
Note: the above increase in performance is with all the other PRs merged in a single file.
Now I also implemented the
I knew logits processing on the CPU is a bottleneck and I was planning to address it, so far I have been benchmarking with temp = 0. Thank you for the change. I will hopefully spend some time testing/understanding it and merge soon.
I just made it fast, and verified that it is working as expected. The result is the same as before, with temperature
There is a simpler way to implement parallel argmax:

```cuda
// CUDA has no float overload of atomicMax, so emulate it with an atomicCAS loop.
__device__ void atomicMaxFloat(float* addr, float val) {
    int* addr_as_int = (int*)addr;
    int old = *addr_as_int;
    while (__int_as_float(old) < val) {
        old = atomicCAS(addr_as_int, old, __float_as_int(val));
    }
}

__global__ void argmax32_kernel(const float* __restrict__ v, int n, int* max_pos, float* max_val) {
    int tid = threadIdx.x + blockIdx.x * blockDim.x;
    if (tid < n) {
        atomicMaxFloat(max_val, v[tid]);
        // Caveat: __syncthreads() only synchronizes threads within one block,
        // so with many blocks this read may race with writes from other blocks.
        __syncthreads();
        if (*max_val == v[tid]) {
            *max_pos = tid;
        }
    }
}
```
In this case it can be launched with many blocks. But I suspected that it would be slower because of the many threads stuck at the lock with just one making a step at a time (that is how I imagine it works; maybe I am wrong). I did not benchmark both to compare, though.
My suggestion: merge this PR, and then later you can attempt other approaches for even higher performance.
Apply the temperature and softmax to the logits using the GPU

The `argmax` and `sample` functions were not changed. This has a drastic improvement in speed, of ~76%, when the temperature is not 0. Tested using the `stories110M.bin` model with an RTX 3090 with PCIe 4.0 at 24 GB/s.