diff --git a/README.md b/README.md index 66166c01b8ee1..0b4efdd33395d 100644 --- a/README.md +++ b/README.md @@ -124,6 +124,7 @@ Typically finetunes of the base models below are supported as well. - Go: [go-skynet/go-llama.cpp](https://github.com/go-skynet/go-llama.cpp) - Node.js: [withcatai/node-llama-cpp](https://github.com/withcatai/node-llama-cpp) - JS/TS (llama.cpp server client): [lgrammel/modelfusion](https://modelfusion.dev/integration/model-provider/llamacpp) +- JavaScript/Wasm (works in browser): [tangledgroup/llama-cpp-wasm](https://github.com/tangledgroup/llama-cpp-wasm) - Ruby: [yoshoku/llama_cpp.rb](https://github.com/yoshoku/llama_cpp.rb) - Rust (nicer API): [mdrokz/rust-llama.cpp](https://github.com/mdrokz/rust-llama.cpp) - Rust (more direct bindings): [utilityai/llama-cpp-rs](https://github.com/utilityai/llama-cpp-rs) diff --git a/examples/llava/README.md b/examples/llava/README.md index 721d5e6139755..19f1a50a235d7 100644 --- a/examples/llava/README.md +++ b/examples/llava/README.md @@ -29,19 +29,25 @@ git clone https://huggingface.co/liuhaotian/llava-v1.5-7b git clone https://huggingface.co/openai/clip-vit-large-patch14-336 ``` -2. Use `llava-surgery.py` to split the LLaVA model to LLaMA and multimodel projector constituents: +2. Install the required Python packages: + +```sh +pip install -r examples/llava/requirements.txt +``` + +3. Use `llava-surgery.py` to split the LLaVA model to LLaMA and multimodel projector constituents: ```sh python ./examples/llava/llava-surgery.py -m ../llava-v1.5-7b ``` -3. Use `convert-image-encoder-to-gguf.py` to convert the LLaVA image encoder to GGUF: +4. Use `convert-image-encoder-to-gguf.py` to convert the LLaVA image encoder to GGUF: ```sh python ./examples/llava/convert-image-encoder-to-gguf.py -m ../clip-vit-large-patch14-336 --llava-projector ../llava-v1.5-7b/llava.projector --output-dir ../llava-v1.5-7b ``` -4. Use `convert.py` to convert the LLaMA part of LLaVA to GGUF: +5. Use `convert.py` to convert the LLaMA part of LLaVA to GGUF: ```sh python ./convert.py ../llava-v1.5-7b diff --git a/examples/llava/llava-surgery.py b/examples/llava/llava-surgery.py index 515f6b58d47f5..0a61efdfe14d1 100644 --- a/examples/llava/llava-surgery.py +++ b/examples/llava/llava-surgery.py @@ -42,5 +42,5 @@ torch.save(checkpoint, path) print("Done!") -print(f"Now you can convert {args.model} to a a regular LLaMA GGUF file.") +print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.") print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.") diff --git a/examples/llava/requirements.txt b/examples/llava/requirements.txt new file mode 100644 index 0000000000000..f80f727a79307 --- /dev/null +++ b/examples/llava/requirements.txt @@ -0,0 +1,3 @@ +-r ../../requirements/requirements-convert.txt +pillow~=10.2.0 +torch~=2.1.1 diff --git a/examples/server/server.cpp b/examples/server/server.cpp index eceda30d05fcc..8d668f7989eed 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1592,10 +1592,6 @@ struct llama_server_context LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed); } - LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past); - - llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1); - slot.cache_tokens = prompt_tokens; if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0) @@ -1609,6 +1605,10 @@ struct llama_server_context } } + LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past); + + llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1); + LOG_VERBOSE("prompt ingested", { {"n_past", slot.n_past}, {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)}, diff --git a/ggml-cuda.cu b/ggml-cuda.cu index db9da24594cb2..5053757e6d41e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -5310,22 +5310,26 @@ template static __global__ void #endif // __CUDA_ARCH__ >= CC_VOLTA } -template +#define MMVQ_NWARPS_NVIDIA 4 +#define MMVQ_NWARPS_AMD_RDNA2 1 +#define MMVQ_NWARPS_AMD_OLD 4 + +template +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +__launch_bounds__(nwarps*WARP_SIZE, 1) // tells the compiler to use as many registers as it wants +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) { const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par; - const int row = blockIdx.x*blockDim.y + threadIdx.y; - - if (row >= nrows_x) { - return; - } + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + const int row = blockIdx.x; const int blocks_per_row_x = ncols_x / qk; const int blocks_per_col_y = nrows_y / QK8_1; - const int blocks_per_warp = vdr * WARP_SIZE / qi; + const int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi; // partial sum for each thread float tmp[ncols_y_template != 0 ? ncols_y_template : 8] = {0.0f}; @@ -5333,12 +5337,12 @@ static __global__ void mul_mat_vec_q( const block_q_t * x = (const block_q_t *) vx; const block_q8_1 * y = (const block_q8_1 *) vy; - for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row_x; i += blocks_per_warp) { + for (int i = tid / (qi/vdr); i < blocks_per_row_x; i += blocks_per_iter) { const int ibx = row*blocks_per_row_x + i; // x block index const int iby = i * (qk/QK8_1); // y block index that aligns with ibx - const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int + const int iqs = vdr * (tid % (qi/vdr)); // x block quant index when casting the quants to int #pragma unroll for (int j = 0; j < ncols_y; ++j) { @@ -5346,9 +5350,25 @@ static __global__ void mul_mat_vec_q( } } + __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y_template != 0 ? ncols_y_template : 8][WARP_SIZE]; + if (threadIdx.y > 0) { +#pragma unroll + for (int j = 0; j < ncols_y; ++j) { + tmp_shared[threadIdx.y-1][j][threadIdx.x] = tmp[j]; + } + } + __syncthreads(); + if (threadIdx.y > 0) { + return; + } + // sum up partial sums and write back result #pragma unroll for (int j = 0; j < ncols_y; ++j) { +#pragma unroll + for (int i = 0; i < nwarps-1; ++i) { + tmp[j] += tmp_shared[i][j][threadIdx.x]; + } tmp[j] = warp_reduce_sum(tmp[j]); if (threadIdx.x == 0) { @@ -6833,46 +6853,65 @@ static void mul_mat_vec_q_cuda( GGML_ASSERT(ncols_x % qk == 0); GGML_ASSERT(ncols_y <= 4); - const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - switch (ncols_y) { - case 1: - mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - break; - case 2: - mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - break; - case 3: - mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - break; - case 4: - mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - break; - // case 5: - // mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot> - // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - // break; - // case 6: - // mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot> - // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - // break; - // case 7: - // mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot> - // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - // break; - // case 8: - // mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot> - // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); - // break; + int id; + CUDA_CHECK(cudaGetDevice(&id)); + + int nwarps; + if (g_device_caps[id].cc >= CC_OFFSET_AMD) { + nwarps = g_device_caps[id].cc >= CC_RDNA2 ? MMVQ_NWARPS_AMD_RDNA2 : MMVQ_NWARPS_AMD_OLD; + } else { + nwarps = MMVQ_NWARPS_NVIDIA; + } + + const dim3 block_nums(nrows_x, 1, 1); + const dim3 block_dims(WARP_SIZE, nwarps, 1); + + switch (nwarps) { + case 1: switch(ncols_y) { + case 1: + mul_mat_vec_q<1, 1, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); + break; + case 2: + mul_mat_vec_q<1, 2, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); + break; + case 3: + mul_mat_vec_q<1, 3, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); + break; + case 4: + mul_mat_vec_q<1, 4, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); + break; + default: + GGML_ASSERT(false); + break; + } break; + case 4: switch(ncols_y) { + case 1: + mul_mat_vec_q<4, 1, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); + break; + case 2: + mul_mat_vec_q<4, 2, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); + break; + case 3: + mul_mat_vec_q<4, 3, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); + break; + case 4: + mul_mat_vec_q<4, 4, qk, qi, block_q_t, vdr, vec_dot> + <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); + break; + default: + GGML_ASSERT(false); + break; + } break; + default: GGML_ASSERT(false); - // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot> - // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst); break; } } diff --git a/ggml-quants.c b/ggml-quants.c index 101d3e783b3bd..1031e3761c3c3 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -268,6 +268,17 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) #if defined(__ARM_NEON) + +#ifdef _MSC_VER + +#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) } + +#else + +#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) } + +#endif + #if !defined(__aarch64__) // 64-bit compatibility @@ -8698,10 +8709,10 @@ void ggml_vec_dot_iq3_xxs_q8_K(const int n, float * restrict s, const void * res for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { q8b = ggml_vld1q_s8_x4(q8); q8 += 64; memcpy(aux32, gas, 2*sizeof(uint32_t)); gas += 2*sizeof(uint32_t); - const uint32x4_t aux32x4_0 = {iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]}; - const uint32x4_t aux32x4_1 = {iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]}; - const uint32x4_t aux32x4_2 = {iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]}; - const uint32x4_t aux32x4_3 = {iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]}; + const uint32x4_t aux32x4_0 = ggml_vld1q_u32(iq3xxs_grid[q3[ 0]], iq3xxs_grid[q3[ 1]], iq3xxs_grid[q3[ 2]], iq3xxs_grid[q3[ 3]]); + const uint32x4_t aux32x4_1 = ggml_vld1q_u32(iq3xxs_grid[q3[ 4]], iq3xxs_grid[q3[ 5]], iq3xxs_grid[q3[ 6]], iq3xxs_grid[q3[ 7]]); + const uint32x4_t aux32x4_2 = ggml_vld1q_u32(iq3xxs_grid[q3[ 8]], iq3xxs_grid[q3[ 9]], iq3xxs_grid[q3[10]], iq3xxs_grid[q3[11]]); + const uint32x4_t aux32x4_3 = ggml_vld1q_u32(iq3xxs_grid[q3[12]], iq3xxs_grid[q3[13]], iq3xxs_grid[q3[14]], iq3xxs_grid[q3[15]]); q3 += 16; q3s.val[0] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 0) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 7) & 127)))); q3s.val[1] = vcombine_s8(vld1_s8((const void *)(signs64 + ((aux32[0] >> 14) & 127))), vld1_s8((const void *)(signs64 + ((aux32[0] >> 21) & 127)))); diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 9e2846ee4d4bd..254f648a66e76 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -744,6 +744,8 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz } if (memory_type_index >= mem_props.memoryTypeCount) { + ctx->device.lock()->device.destroyBuffer(buf->buffer); + buf->size = 0; throw vk::OutOfDeviceMemoryError("No suitable memory type found"); } @@ -3875,7 +3877,7 @@ static ggml_tensor * ggml_vk_find_last_use(const ggml_tensor * node, ggml_cgraph static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggml_tensor * node){ #ifdef GGML_VULKAN_DEBUG - std::cerr << "ggml_ctx->preallocate_buffers_graph(" << node << ")" << std::endl; + std::cerr << "ggml_vk_preallocate_buffers_graph(" << node << ")" << std::endl; #endif const bool any_on_device = node->backend == GGML_BACKEND_GPU || (node->src[0] != nullptr && (node->src[0]->backend == GGML_BACKEND_GPU || node->src[0]->backend == GGML_BACKEND_GPU_SPLIT)) @@ -3994,8 +3996,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { return; } #ifdef GGML_VULKAN_DEBUG - std::cerr << "ggml_ctx->preallocate_buffers()" << std::endl; - std::cerr << "qx_size: " << ctx->prealloc_size_qx << " qy_size: " << ctx->prealloc_size_qy << " x_size: " << ctx->prealloc_size_x << " y_size: " << ctx->prealloc_size_y << " split_k_size: " << ctx->prealloc_size_split_k << std::endl; + std::cerr << "ggml_vk_preallocate_buffers(qx_size: " << ctx->prealloc_size_qx << " qy_size: " << ctx->prealloc_size_qy << " x_size: " << ctx->prealloc_size_x << " y_size: " << ctx->prealloc_size_y << " split_k_size: " << ctx->prealloc_size_split_k << ")" << std::endl; #endif #if defined(GGML_VULKAN_RUN_TESTS) ctx->staging = ggml_vk_create_buffer_check(ctx, 100ul * 1024ul * 1024ul, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached); diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py index 4abb0383f2d5d..b2e86e1821305 100644 --- a/ggml_vk_generate_shaders.py +++ b/ggml_vk_generate_shaders.py @@ -2067,6 +2067,8 @@ K_QUANTS_PER_ITERATION = 2 +ASYNCIO_CONCURRENCY = 64 + output_dir = gettempdir() lock = asyncio.Lock() @@ -2291,7 +2293,14 @@ async def main(): tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"})) tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"})) - await asyncio.gather(*tasks) + # Helper to decorate tasks with semaphore acquisition. + async def withSemaphore(sem, task): + async with sem: + return await task + + # Run tasks concurrently guarded by a concurrency limit. + sem = asyncio.Semaphore(ASYNCIO_CONCURRENCY) + await asyncio.gather(*(withSemaphore(sem, task) for task in tasks)) with open("ggml-vulkan-shaders.hpp", "w") as f: f.write("#include \n\n") diff --git a/llama.cpp b/llama.cpp index 89acafbc3ffd7..0566b087b2e12 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4209,8 +4209,7 @@ static bool llm_load_tensors( ctx_bufs.emplace_back(ctx, buf); } - // print memory requirements - { + if (llama_supports_gpu_offload()) { const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu); @@ -4222,10 +4221,11 @@ static bool llm_load_tensors( const int max_offloadable_layers = hparams.n_layer + 1; LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); + } - for (ggml_backend_buffer_t buf : model.bufs) { - LLAMA_LOG_INFO("%s: %10s buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0); - } + // print memory requirements + for (ggml_backend_buffer_t buf : model.bufs) { + LLAMA_LOG_INFO("%s: %10s buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf) / 1024.0 / 1024.0); } // populate tensors_by_name @@ -7285,7 +7285,9 @@ static int llama_decode_internal( // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well // we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering // with the BLAS calls. need a better solution - if (n_tokens >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) { + // MoE Special Case: This logic applies when hparams.n_expert == 0, i.e. the model is NOT an MoE model. When an MoE is + // being processed then Accelerate/BLAS will not be involved, so capping would limit performance. + if (n_tokens >= 32 && hparams.n_expert == 0 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) { n_threads = std::min(4, n_threads); }