ggml-cuda : move row numbers to x grid dim in mul mat vec kernels #3921
Conversation
Tested Bloom 1B with n_vocab == 250880 and it seems to work correctly on a V100:
LLAMA_CUBLAS=1 make -j main && ./main -m /mnt/llama.cpp/models/bloom-1b/ggml-model-f16.gguf -p "I believe the meaning of live is" -t 6 -ngl 99 -n 128 -s 123
llm_load_tensors: ggml ctx size = 0.11 MB
llm_load_tensors: using CUDA for GPU acceleration
llm_load_tensors: mem required = 980.12 MB
llm_load_tensors: offloading 24 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 27/27 layers to GPU
llm_load_tensors: VRAM used: 3286.45 MB
.........................................................
llama_new_context_with_model: n_ctx = 512
llama_new_context_with_model: freq_base = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: offloading v cache to GPU
llama_kv_cache_init: offloading k cache to GPU
llama_kv_cache_init: VRAM kv self = 96.00 MB
llama_new_context_with_model: kv self size = 96.00 MB
llama_build_graph: non-view tensors processed: 632/632
llama_new_context_with_model: compute buffer total size = 500.63 MB
llama_new_context_with_model: VRAM scratch buffer: 494.00 MB
llama_new_context_with_model: total VRAM used: 3876.45 MB (model: 3286.45 MB, context: 590.00 MB)
system_info: n_threads = 6 / 6 | AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 |
sampling:
repeat_last_n = 64, repeat_penalty = 1.100, frequency_penalty = 0.000, presence_penalty = 0.000
top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
generate: n_ctx = 512, n_batch = 512, n_predict = 128, n_keep = 0
I believe the meaning of live is not only living in this world, but also being able to continue your life with others.
As a human who has lived my whole life in Japan and works abroad for many years, I have been interested in what makes people happy. The idea that someone can be happier by having another person’s happiness as their goal led me to think about how we should live our lives with the other person so they become even more happy.
I feel like there are two types of happiness: one is a feeling and it comes from within, but I believe this kind of happiness only lasts for an instant. The other type that comes in response to external
llama_print_timings: load time = 960.92 ms
llama_print_timings: sample time = 451.82 ms / 128 runs ( 3.53 ms per token, 283.30 tokens per second)
llama_print_timings: prompt eval time = 18.95 ms / 7 tokens ( 2.71 ms per token, 369.39 tokens per second)
llama_print_timings: eval time = 986.73 ms / 127 runs ( 7.77 ms per token, 128.71 tokens per second)
llama_print_timings: total time = 1716.06 ms
Log end
@@ -4874,7 +4874,7 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
 static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
-    const dim3 block_nums(1, block_num_y, 1);
+    const dim3 block_nums(block_num_y, 1, 1);
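For context, a simplified sketch of how the kernel-side row index follows this launch change: the row is now derived from blockIdx.x instead of blockIdx.y. This is not the actual ggml-cuda kernel (it uses plain floats instead of q4_0 blocks, and the names dmmv_sketch and MMV_Y are illustrative only).

// Simplified mat-vec sketch; names are hypothetical, not real ggml-cuda symbols.
#define MMV_Y 1

static __global__ void dmmv_sketch(const float * x, const float * y, float * dst,
                                   const int ncols, const int nrows) {
    // Row index now comes from the x grid dimension (previously blockIdx.y).
    const int row = blockIdx.x * blockDim.y + threadIdx.y;
    if (row >= nrows) {
        return;
    }

    // Each lane accumulates a strided part of the row.
    float sum = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        sum += x[row * ncols + col] * y[col];
    }
    // Reduce the partial sums across the warp (blockDim.x == 32).
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_xor_sync(0xffffffff, sum, offset);
    }
    if (threadIdx.x == 0) {
        dst[row] = sum;
    }
}

static void dmmv_sketch_cuda(const float * x, const float * y, float * dst,
                             const int ncols, const int nrows, cudaStream_t stream) {
    const int block_num = (nrows + MMV_Y - 1) / MMV_Y;
    const dim3 block_nums(block_num, 1, 1);   // rows mapped to the x grid dimension
    const dim3 block_dims(32, MMV_Y, 1);
    dmmv_sketch<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols, nrows);
}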
I think a comment to explain the issue with vocab size would be useful, otherwise LGTM.
In models with large vocabs (>65535), the number of rows of the output matrix may exceed the maximum grid size in the y dimension. This change moves the row numbers to the x dimension which has a much larger limit.
Fixes #3740 #3697
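To illustrate the limit in question, here is a small standalone host-side sketch (not part of this PR) that queries the device grid limits via the CUDA runtime API. On current NVIDIA GPUs the y and z grid dimensions are capped at 65535 blocks, while x allows up to 2^31 - 1, so an output with e.g. 250880 rows only fits in the x dimension.

#include <cstdio>
#include <cuda_runtime.h>

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);

    // Typically prints roughly 2147483647 x 65535 x 65535.
    printf("max grid size: %d x %d x %d\n",
           prop.maxGridSize[0], prop.maxGridSize[1], prop.maxGridSize[2]);

    const int nrows = 250880; // e.g. Bloom's n_vocab
    printf("nrows = %d fits in x: %s, fits in y: %s\n", nrows,
           nrows <= prop.maxGridSize[0] ? "yes" : "no",
           nrows <= prop.maxGridSize[1] ? "yes" : "no");
    return 0;
}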