diff --git a/common/common.cpp b/common/common.cpp index d4f9dbf5562992..d6d07329f7932d 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -198,8 +198,30 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.rope_freq_scale = 1.0f/std::stof(argv[i]); + } else if (arg == "--kv-type" || arg == "-kvt") { + if (++i >= argc) { + invalid_param = true; + break; + } + + std::string type_name(argv[i]); + for (char & c : type_name) { + c = std::tolower(c); + } + + if (type_name == "q8_0") { + params.kv_type = GGML_TYPE_Q8_0; + } else if (type_name == "f16") { + params.kv_type = GGML_TYPE_F16; + } else if (type_name == "f32") { + params.kv_type = GGML_TYPE_F32; + } else { + fprintf(stderr, "error: unknown KV type: %s. Known types: q8_0, f16, f32.\n", argv[i]); + invalid_param = true; + break; + } } else if (arg == "--memory-f32") { - params.memory_f16 = false; + params.kv_type = GGML_TYPE_F32; } else if (arg == "--top-p") { if (++i >= argc) { invalid_param = true; break; } @@ -643,8 +665,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " --rope-freq-scale N RoPE frequency linear scaling factor, inverse of --rope-scale (default: %g)\n", params.rope_freq_scale); fprintf(stdout, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); fprintf(stdout, " --no-penalize-nl do not penalize newline token\n"); - fprintf(stdout, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); - fprintf(stdout, " not recommended: doubles context memory required and no measurable increase in quality\n"); + fprintf(stdout, " -kvt, --kv-type the type to use for the KV cache (default: q8_0; alternatives: f16, f32)\n"); fprintf(stdout, " --temp N temperature (default: %.1f)\n", (double)params.temp); fprintf(stdout, " --perplexity compute perplexity over each ctx window of the prompt\n"); fprintf(stdout, " --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n"); @@ -725,7 +746,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param lparams.low_vram = params.low_vram; lparams.mul_mat_q = params.mul_mat_q; lparams.seed = params.seed; - lparams.f16_kv = params.memory_f16; + lparams.kv_type = params.kv_type; lparams.use_mmap = params.use_mmap; lparams.use_mlock = params.use_mlock; lparams.logits_all = params.perplexity; @@ -1191,6 +1212,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "interactive: %s # default: false\n", params.interactive ? "true" : "false"); fprintf(stream, "interactive_first: %s # default: false\n", params.interactive_first ? "true" : "false"); fprintf(stream, "keep: %d # default: 0\n", params.n_keep); + fprintf(stream, "kv_type: %s # default: q8_0\n", ggml_type_name(params.kv_type)); fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str()); fprintf(stream, "logit_bias:\n"); @@ -1205,7 +1227,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "lora_base: %s\n", params.lora_base.c_str()); fprintf(stream, "low_vram: %s # default: false\n", params.low_vram ? "true" : "false"); fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); - fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ?
"true" : "false"); fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", params.mirostat); fprintf(stream, "mirostat_ent: %f # default: 5.0\n", params.mirostat_tau); fprintf(stream, "mirostat_lr: %f # default: 0.1\n", params.mirostat_eta); diff --git a/common/common.h b/common/common.h index 85ac0df9b5b3d7..0fa93b1e17899f 100644 --- a/common/common.h +++ b/common/common.h @@ -84,9 +84,10 @@ struct gpt_params { bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score + ggml_type kv_type = GGML_TYPE_Q8_0; // the type to use for the KV cache + bool low_vram = false; // if true, reduce VRAM usage at the cost of performance bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS - bool memory_f16 = true; // use f16 instead of f32 for memory kv bool random_prompt = false; // do not randomize prompt if none provided bool use_color = false; // use color to distinguish generations and inputs bool interactive = false; // interactive mode diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index bf3a487abd3054..457525e8b8b707 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -135,7 +135,7 @@ struct cmd_params { std::vector n_prompt; std::vector n_gen; std::vector n_batch; - std::vector f32_kv; + std::vector kv_type; std::vector n_threads; std::vector n_gpu_layers; std::vector main_gpu; @@ -152,7 +152,7 @@ static const cmd_params cmd_params_defaults = { /* n_prompt */ {512}, /* n_gen */ {128}, /* n_batch */ {512}, - /* f32_kv */ {false}, + /* kv_type */ {GGML_TYPE_Q8_0}, /* n_threads */ {get_num_physical_cores()}, /* n_gpu_layers */ {99}, /* main_gpu */ {0}, @@ -173,7 +173,16 @@ static void print_usage(int /* argc */, char ** argv) { fprintf(stdout, " -p, --n-prompt (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str()); fprintf(stdout, " -n, --n-gen (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str()); fprintf(stdout, " -b, --batch-size (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str()); - fprintf(stdout, " --memory-f32 <0|1> (default: %s)\n", join(cmd_params_defaults.f32_kv, ",").c_str()); + + std::string kv_type_default; + for (unsigned int i = 0; i < cmd_params_defaults.kv_type.size(); ++i) { + if (i > 0) { + kv_type_default += ","; + } + kv_type_default += ggml_type_name(cmd_params_defaults.kv_type[i]); + } + fprintf(stdout, " -kvt, kv_type (default: %s)\n", kv_type_default.c_str()); + fprintf(stdout, " -t, --threads (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str()); fprintf(stdout, " -ngl N, --n-gpu-layers (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str()); fprintf(stdout, " -mg i, --main-gpu (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str()); @@ -236,13 +245,32 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { } auto p = split(argv[i], split_delim); params.n_batch.insert(params.n_batch.end(), p.begin(), p.end()); - } else if (arg == "--memory-f32") { + } else if (arg == "-kvt" || arg == "--kv-type") { if (++i >= argc) { invalid_param = true; break; } - auto p = split(argv[i], split_delim); - params.f32_kv.insert(params.f32_kv.end(), p.begin(), p.end()); + auto p = split(argv[i], split_delim); + + std::vector kvt; + for (const std::string & type_name : p) { + if (type_name == "q8_0") { + kvt.push_back(GGML_TYPE_Q8_0); + } else if (type_name == "f16") { + 
kvt.push_back(GGML_TYPE_F16); + } else if (type_name == "f32") { + kvt.push_back(GGML_TYPE_F32); + } else { + invalid_param = true; + break; + } + } + if (invalid_param) { + fprintf(stderr, "error: unknown KV type: %s. Known types: Q8_0, F16, F32.\n", argv[i]); + break; + } + + params.kv_type.insert(params.kv_type.end(), kvt.begin(), kvt.end()); } else if (arg == "-t" || arg == "--threads") { if (++i >= argc) { invalid_param = true; @@ -340,7 +368,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (params.n_prompt.empty()) { params.n_prompt = cmd_params_defaults.n_prompt; } if (params.n_gen.empty()) { params.n_gen = cmd_params_defaults.n_gen; } if (params.n_batch.empty()) { params.n_batch = cmd_params_defaults.n_batch; } - if (params.f32_kv.empty()) { params.f32_kv = cmd_params_defaults.f32_kv; } + if (params.kv_type.empty()) { params.kv_type = cmd_params_defaults.kv_type; } if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; } if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; } if (params.mul_mat_q.empty()) { params.mul_mat_q = cmd_params_defaults.mul_mat_q; } @@ -356,7 +384,7 @@ struct cmd_params_instance { int n_prompt; int n_gen; int n_batch; - bool f32_kv; + ggml_type kv_type; int n_threads; int n_gpu_layers; int main_gpu; @@ -368,7 +396,7 @@ struct cmd_params_instance { llama_context_params lparams = llama_context_default_params(); lparams.n_ctx = n_prompt + n_gen; lparams.n_batch = n_batch; - lparams.f16_kv = !f32_kv; + lparams.kv_type = kv_type; lparams.n_gpu_layers = n_gpu_layers; lparams.main_gpu = main_gpu; lparams.mul_mat_q = mul_mat_q; @@ -384,7 +412,7 @@ static std::vector get_cmd_params_instances_int(const cmd_p for (const auto & m : params.model) for (const auto & nb : params.n_batch) - for (const auto & fk : params.f32_kv) + for (const auto & kvt : params.kv_type) for (const auto & nl : params.n_gpu_layers) for (const auto & mg : params.main_gpu) for (const auto & mmq : params.mul_mat_q) @@ -396,7 +424,7 @@ static std::vector get_cmd_params_instances_int(const cmd_p /* .n_prompt = */ n_prompt, /* .n_gen = */ n_gen, /* .n_batch = */ nb, - /* .f32_kv = */ fk, + /* .kv_type = */ kvt, /* .n_threads = */ nt, /* .n_gpu_layers = */ nl, /* .main_gpu = */ mg, @@ -447,7 +475,7 @@ struct test { uint64_t model_n_params; int n_batch; int n_threads; - bool f32_kv; + ggml_type kv_type; int n_gpu_layers; int main_gpu; bool mul_mat_q; @@ -467,7 +495,7 @@ struct test { model_n_params = llama_model_n_params(lmodel); n_batch = inst.n_batch; n_threads = inst.n_threads; - f32_kv = inst.f32_kv; + kv_type = inst.kv_type; n_gpu_layers = inst.n_gpu_layers; main_gpu = inst.main_gpu; mul_mat_q = inst.mul_mat_q; @@ -531,7 +559,7 @@ struct test { "cuda", "opencl", "metal", "gpu_blas", "blas", "cpu_info", "gpu_info", "model_filename", "model_type", "model_size", "model_n_params", - "n_batch", "n_threads", "f16_kv", + "n_batch", "n_threads", "kv_type", "n_gpu_layers", "main_gpu", "mul_mat_q", "low_vram", "tensor_split", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", @@ -551,7 +579,7 @@ struct test { return INT; } if (field == "cuda" || field == "opencl" || field == "metal" || field == "gpu_blas" || field == "blas" || - field == "f16_kv" || field == "mul_mat_q" || field == "low_vram") { + field == "mul_mat_q" || field == "low_vram") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -581,7 +609,7 @@ struct test { std::to_string(cuda), std::to_string(opencl), std::to_string(metal), 
std::to_string(gpu_blas), std::to_string(blas), cpu_info, gpu_info, model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params), - std::to_string(n_batch), std::to_string(n_threads), std::to_string(!f32_kv), + std::to_string(n_batch), std::to_string(n_threads), std::string(ggml_type_name(kv_type)), std::to_string(n_gpu_layers), std::to_string(main_gpu), std::to_string(mul_mat_q), std::to_string(low_vram), tensor_split_str, std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), @@ -765,8 +793,8 @@ struct markdown_printer : public printer { if (params.n_batch.size() > 1 || params.n_batch != cmd_params_defaults.n_batch) { fields.push_back("n_batch"); } - if (params.f32_kv.size() > 1 || params.f32_kv != cmd_params_defaults.f32_kv) { - fields.push_back("f16_kv"); + if (params.kv_type.size() > 1 || params.kv_type != cmd_params_defaults.kv_type) { + fields.push_back("kv_type"); } if (params.main_gpu.size() > 1 || params.main_gpu != cmd_params_defaults.main_gpu) { fields.push_back("main_gpu"); @@ -834,6 +862,9 @@ struct markdown_printer : public printer { } else if (field == "t/s") { snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.avg_ts(), t.stdev_ts()); value = buf; + } else if (field == "kv_type") { + snprintf(buf, sizeof(buf), "%s", ggml_type_name(t.kv_type)); + value = buf; } else if (vmap.find(field) != vmap.end()) { value = vmap.at(field); } else { diff --git a/examples/main/README.md b/examples/main/README.md index 2773fe976b57d7..2bf2eaf4ee6039 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -276,9 +276,9 @@ These options help improve the performance and memory usage of the LLaMA models. - `--numa`: Attempt optimizations that help on some systems with non-uniform memory access. This currently consists of pinning an equal proportion of the threads to the cores on each NUMA node, and disabling prefetch and readahead for mmap. The latter causes mapped pages to be faulted in on first access instead of all at once, and in combination with pinning threads to NUMA nodes, more of the pages end up on the NUMA node where they are used. Note that if the model is already in the system page cache, for example because of a previous run without this option, this will have little effect unless you drop the page cache first. This can be done by rebooting the system or on Linux by writing '3' to '/proc/sys/vm/drop\_caches' as root. -### Memory Float 32 +### KV cache type -- `--memory-f32`: Use 32-bit floats instead of 16-bit floats for memory key+value. This doubles the context memory requirement and cached prompt file size but does not appear to increase generation quality in a measurable way. Not recommended. +- `-kvt, --kv-type`: The data type to use for the KV cache. Uses q8_0 by default. Alternatives are f16 and f32. The alternatives increase memory consumption for marginal quality differences. 
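To make the memory tradeoff above concrete, here is a rough, self-contained sketch (not part of the patch; the 7B-style dimensions n_layer = 32, n_embd = 4096 and the 4096-token context are illustrative assumptions). It relies only on the per-value storage costs of the three types: 4 bytes for f32, 2 bytes for f16, and 34 bytes per 32-value block for q8_0 (32 int8 quants plus one f16 scale):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Illustrative 7B-style dimensions (assumptions, not values taken from this patch).
    const uint64_t n_layer = 32;
    const uint64_t n_embd  = 4096;
    const uint64_t n_ctx   = 4096;

    // K and V each hold n_layer * n_ctx * n_embd values.
    const uint64_t n_values = 2 * n_layer * n_ctx * n_embd;

    const double gib = 1024.0 * 1024.0 * 1024.0;
    printf("f32 : %.2f GiB\n", n_values * 4.0         / gib); // 4 bytes per value
    printf("f16 : %.2f GiB\n", n_values * 2.0         / gib); // 2 bytes per value
    printf("q8_0: %.2f GiB\n", n_values * (34.0/32.0) / gib); // 34 bytes per 32-value block
    return 0;
}
```

Under these assumptions a full-context KV cache drops from about 4 GiB (f32) or 2 GiB (f16) to roughly 1.06 GiB with the q8_0 default.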
### Batch Size diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 06ce18f09a346a..f06d7151abf98b 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -328,7 +328,7 @@ int main(int argc, char ** argv) { lparams.n_ctx = 256; lparams.seed = 1; - lparams.f16_kv = false; + lparams.kv_type = GGML_TYPE_F32; lparams.use_mlock = false; model = llama_load_model_from_file(params.model.c_str(), lparams); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 573bc4ef988a69..93c92b76b5ff0b 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -27,7 +27,7 @@ int main(int argc, char ** argv) { lparams.n_ctx = params.n_ctx; lparams.seed = params.seed; - lparams.f16_kv = params.memory_f16; + lparams.kv_type = params.kv_type; lparams.use_mmap = params.use_mmap; lparams.use_mlock = params.use_mlock; diff --git a/examples/server/README.md b/examples/server/README.md index 51760804638396..8c5a96de4400fe 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -13,7 +13,7 @@ Command line options: - `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS. - `-lv, --low-vram`: Do not allocate a VRAM scratch buffer for holding temporary results. Reduces VRAM usage at the cost of performance, particularly prompt processing speed. Requires cuBLAS. - `-b N`, `--batch-size N`: Set the batch size for prompt processing. Default: `512`. -- `--memory-f32`: Use 32-bit floats instead of 16-bit floats for memory key+value. Not recommended. +- `-kvt, --kv-type`: The data type to use for the KV cache. Uses q8_0 by default. Alternatives are f16 and f32. The alternatives increase memory consumption for marginal quality differences. - `--mlock`: Lock the model in memory, preventing it from being swapped out when memory-mapped. - `--no-mmap`: Do not memory-map the model. By default, models are mapped into memory, which allows the system to load only the necessary parts of the model as needed. - `--numa`: Attempt optimizations that help on some NUMA systems. 
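The CUDA changes that follow quantize data into 32-value blocks on the fly: q8_0 for the KV cache itself (cpy_f32_q8_0) and q8_1 for the activations it is multiplied against (quantize_q8_1), both using the rule d = amax / 127, q = round(x / d). As a plain C++ reference for that rule only, here is a simplified sketch; block_q8_0_ref is a local stand-in rather than ggml's actual struct (which stores the scale as fp16), and strides, padding and partially filled blocks are ignored:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Local stand-in for a q8_0 block: one scale plus 32 int8 quants.
// (ggml stores the scale as fp16; a float keeps this sketch short.)
struct block_q8_0_ref {
    float  d;
    int8_t qs[32];
};

// Quantize n floats (n a multiple of 32) with the rule used by the kernels in this patch:
// d = amax / 127, q = round(x / d), and q = 0 for an all-zero block.
static std::vector<block_q8_0_ref> quantize_q8_0_ref(const float * x, int n) {
    std::vector<block_q8_0_ref> blocks(n / 32);
    for (int ib = 0; ib < n / 32; ++ib) {
        float amax = 0.0f;
        for (int j = 0; j < 32; ++j) {
            amax = std::fmax(amax, std::fabs(x[32*ib + j]));
        }
        const float d = amax / 127.0f;
        blocks[ib].d = d;
        for (int j = 0; j < 32; ++j) {
            blocks[ib].qs[j] = amax == 0.0f ? 0 : (int8_t) std::roundf(x[32*ib + j] / d);
        }
    }
    return blocks;
}

int main() {
    // Round-trip a toy "KV row" and report the worst-case absolute error.
    std::vector<float> x(64);
    for (int i = 0; i < 64; ++i) {
        x[i] = std::sin(0.1f * i);
    }
    const std::vector<block_q8_0_ref> blocks = quantize_q8_0_ref(x.data(), (int) x.size());

    float max_err = 0.0f;
    for (int i = 0; i < 64; ++i) {
        const float y = blocks[i/32].d * blocks[i/32].qs[i%32]; // dequantize: x ~= d * q
        max_err = std::fmax(max_err, std::fabs(x[i] - y));
    }
    printf("max abs round-trip error: %g\n", max_err);
    return 0;
}
```

The real kernels additionally handle non-contiguous sources and blocks that straddle the copy boundary (the first_incomplete/last_incomplete paths in cpy_f32_q8_0); the sketch is only meant to show the numerics the cache accepts in exchange for the smaller footprint.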
diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 94def943b9a0a3..780d6ff13a0531 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -704,8 +704,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, fprintf(stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base); fprintf(stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale); fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); - fprintf(stdout, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); - fprintf(stdout, " not recommended: doubles context memory required and no measurable increase in quality\n"); + fprintf(stdout, " -kvt, --kv-type the type to use for the KV cache (default: q8_0; alternatives: f16, f32)\n"); if (llama_mlock_supported()) { fprintf(stdout, " --mlock force system to keep model in RAM rather than swapping or compressing\n"); @@ -838,9 +837,33 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } params.rope_freq_scale = std::stof(argv[i]); } + else if (arg == "--kv-type" || arg == "-kvt") + { + if (++i >= argc) { + invalid_param = true; + break; + } + + std::string type_name(argv[i]); + for (char & c : type_name) { + c = std::tolower(c); + } + + if (type_name == "q8_0") { + params.kv_type = GGML_TYPE_Q8_0; + } else if (type_name == "f16") { + params.kv_type = GGML_TYPE_F16; + } else if (type_name == "f32") { + params.kv_type = GGML_TYPE_F32; + } else { + fprintf(stderr, "error: unknown KV type: %s. Known types: q8_0, f16, f32.\n", argv[i]); + invalid_param = true; + break; + } + } else if (arg == "--memory-f32" || arg == "--memory_f32") { - params.memory_f16 = false; + params.kv_type = GGML_TYPE_F32; } else if (arg == "--threads" || arg == "-t") { diff --git a/ggml-cuda.cu b/ggml-cuda.cu index d2dbf824ef2da9..65f6234e7fb74f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -56,6 +56,7 @@ #define cudaMemcpyHostToDevice hipMemcpyHostToDevice #define cudaMemcpyKind hipMemcpyKind #define cudaMemset hipMemset +#define cudaMemsetAsync hipMemsetAsync #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize #define cudaSetDevice hipSetDevice #define cudaStreamCreateWithFlags hipStreamCreateWithFlags @@ -1464,23 +1465,30 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs, v.y = x[ib + iqs + 1]; } -static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) { +static __global__ void quantize_q8_1( + const float * __restrict__ src, void * __restrict__ vdst, const int kx, const int kx_padded, const int ky, + const int ky_stride, const int channel_stride) { + const int ix = blockDim.x*blockIdx.x + threadIdx.x; if (ix >= kx_padded) { return; } - const int iy = blockDim.y*blockIdx.y + threadIdx.y; + const int iy = blockDim.y*blockIdx.y + threadIdx.y; + const int channel = blockDim.z*blockIdx.z + threadIdx.z; - const int i_padded = iy*kx_padded + ix; + // padded and contiguous: + const int i_padded = channel*ky*kx_padded + iy*kx_padded + ix; - block_q8_1 * y = (block_q8_1 *) vy; + block_q8_1 * dst = (block_q8_1 *) vdst; const int ib = i_padded / QK8_1; // block index const int iqs = i_padded % QK8_1; // quant index - const float xi = ix < kx ? x[iy*kx + ix] : 0.0f; + // not padded and not necessarily contiguous: + const float xi = ix < kx ? 
src[channel*channel_stride + iy*ky_stride + ix] : 0.0f; + float amax = fabsf(xi); float sum = xi; @@ -1493,14 +1501,14 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest const float d = amax / 127; const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); - y[ib].qs[iqs] = q; + dst[ib].qs[iqs] = q; if (iqs > 0) { return; } - reinterpret_cast(y[ib].ds.x) = d; - reinterpret_cast(y[ib].ds.y) = sum; + reinterpret_cast(dst[ib].ds.x) = d; + reinterpret_cast(dst[ib].ds.y) = sum; } template @@ -3338,10 +3346,11 @@ template static __device__ __forceinline__ void mul_mat_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const int row_stride_x, const int channel_stride_x, const int channel_stride_y) { - const block_q_t * x = (const block_q_t *) vx; - const block_q8_1 * y = (const block_q8_1 *) vy; + const block_q_t * x = ((const block_q_t *) vx) + blockIdx.z*channel_stride_x; + const block_q8_1 * y = ((const block_q8_1 *) vy) + blockIdx.z*channel_stride_y; const int blocks_per_row_x = ncols_x / qk; const int blocks_per_col_y = nrows_y / QK8_1; @@ -3369,8 +3378,8 @@ static __device__ __forceinline__ void mul_mat_q( for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) { - load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, - threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x); + load_tiles(x + row_x_0*row_stride_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, + threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, row_stride_x); #pragma unroll for (int ir = 0; ir < qr; ++ir) { @@ -3439,7 +3448,7 @@ static __device__ __forceinline__ void mul_mat_q( continue; } - dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps]; + dst[blockIdx.z*ncols_dst*nrows_dst + col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps]; } } } @@ -3453,29 +3462,27 @@ static __device__ __forceinline__ void mul_mat_q( template static __global__ void mul_mat_q4_0( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const int row_stride_x, const int channel_stride_x, const int channel_stride_y) { #if __CUDA_ARCH__ >= CC_TURING const int mmq_x = MMQ_X_Q4_0_AMPERE; const int mmq_y = MMQ_Y_Q4_0_AMPERE; const int nwarps = NWARPS_Q4_0_AMPERE; - - mul_mat_q, - load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - #elif __CUDA_ARCH__ >= MIN_CC_DP4A const int mmq_x = MMQ_X_Q4_0_PASCAL; const int mmq_y = MMQ_Y_Q4_0_PASCAL; const int nwarps = NWARPS_Q4_0_PASCAL; - - mul_mat_q, - load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); #else - (void) vec_dot_q4_0_q8_1_mul_mat; + const int mmq_x = -1; + const int mmq_y = -1; + const int nwarps = -1; assert(false); #endif // __CUDA_ARCH__ >= CC_TURING + + mul_mat_q, + load_tiles_q4_0, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y); } #define MMQ_X_Q4_1_AMPERE 64 @@ -3491,29 +3498,27 
@@ template static __global__ void #endif // __CUDA_ARCH__ < CC_TURING mul_mat_q4_1( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const int row_stride_x, const int channel_stride_x, const int channel_stride_y) { #if __CUDA_ARCH__ >= CC_TURING const int mmq_x = MMQ_X_Q4_1_AMPERE; const int mmq_y = MMQ_Y_Q4_1_AMPERE; const int nwarps = NWARPS_Q4_1_AMPERE; - - mul_mat_q, - load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - #elif __CUDA_ARCH__ >= MIN_CC_DP4A const int mmq_x = MMQ_X_Q4_1_PASCAL; const int mmq_y = MMQ_Y_Q4_1_PASCAL; const int nwarps = NWARPS_Q4_1_PASCAL; - - mul_mat_q, - load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); #else - (void) vec_dot_q4_1_q8_1_mul_mat; + const int mmq_x = -1; + const int mmq_y = -1; + const int nwarps = -1; assert(false); #endif // __CUDA_ARCH__ >= CC_TURING + + mul_mat_q, + load_tiles_q4_1, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y); } #define MMQ_X_Q5_0_AMPERE 128 @@ -3525,29 +3530,27 @@ template static __global__ void template static __global__ void mul_mat_q5_0( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const int row_stride_x, const int channel_stride_x, const int channel_stride_y) { #if __CUDA_ARCH__ >= CC_TURING const int mmq_x = MMQ_X_Q5_0_AMPERE; const int mmq_y = MMQ_Y_Q5_0_AMPERE; const int nwarps = NWARPS_Q5_0_AMPERE; - - mul_mat_q, - load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - #elif __CUDA_ARCH__ >= MIN_CC_DP4A const int mmq_x = MMQ_X_Q5_0_PASCAL; const int mmq_y = MMQ_Y_Q5_0_PASCAL; const int nwarps = NWARPS_Q5_0_PASCAL; - - mul_mat_q, - load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); #else - (void) vec_dot_q5_0_q8_1_mul_mat; + const int mmq_x = -1; + const int mmq_y = -1; + const int nwarps = -1; assert(false); #endif // __CUDA_ARCH__ >= CC_TURING + + mul_mat_q, + load_tiles_q5_0, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y); } #define MMQ_X_Q5_1_AMPERE 128 @@ -3559,29 +3562,27 @@ template static __global__ void mul_mat_q5_0( template static __global__ void mul_mat_q5_1( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const int row_stride_x, const int channel_stride_x, const int channel_stride_y) { #if __CUDA_ARCH__ >= CC_TURING const int mmq_x = MMQ_X_Q5_1_AMPERE; const int mmq_y = MMQ_Y_Q5_1_AMPERE; const int nwarps = NWARPS_Q5_1_AMPERE; - - mul_mat_q, - load_tiles_q5_1, 
VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - #elif __CUDA_ARCH__ >= MIN_CC_DP4A const int mmq_x = MMQ_X_Q5_1_PASCAL; const int mmq_y = MMQ_Y_Q5_1_PASCAL; const int nwarps = NWARPS_Q5_1_PASCAL; - - mul_mat_q, - load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); #else - (void) vec_dot_q5_1_q8_1_mul_mat; + const int mmq_x = -1; + const int mmq_y = -1; + const int nwarps = -1; assert(false); #endif // __CUDA_ARCH__ >= CC_TURING + + mul_mat_q, + load_tiles_q5_1, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y); } #define MMQ_X_Q8_0_AMPERE 128 @@ -3593,29 +3594,27 @@ template static __global__ void mul_mat_q5_1( template static __global__ void mul_mat_q8_0( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const int row_stride_x, const int channel_stride_x, const int channel_stride_y) { #if __CUDA_ARCH__ >= CC_TURING const int mmq_x = MMQ_X_Q8_0_AMPERE; const int mmq_y = MMQ_Y_Q8_0_AMPERE; const int nwarps = NWARPS_Q8_0_AMPERE; - - mul_mat_q, - load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - #elif __CUDA_ARCH__ >= MIN_CC_DP4A const int mmq_x = MMQ_X_Q8_0_PASCAL; const int mmq_y = MMQ_Y_Q8_0_PASCAL; const int nwarps = NWARPS_Q8_0_PASCAL; - - mul_mat_q, - load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); #else - (void) vec_dot_q8_0_q8_1_mul_mat; + const int mmq_x = -1; + const int mmq_y = -1; + const int nwarps = -1; assert(false); #endif // __CUDA_ARCH__ >= CC_TURING + + mul_mat_q, + load_tiles_q8_0, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y); } #define MMQ_X_Q2_K_AMPERE 64 @@ -3627,29 +3626,27 @@ template static __global__ void mul_mat_q8_0( template static __global__ void mul_mat_q2_K( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const int row_stride_x, const int channel_stride_x, const int channel_stride_y) { #if __CUDA_ARCH__ >= CC_TURING const int mmq_x = MMQ_X_Q2_K_AMPERE; const int mmq_y = MMQ_Y_Q2_K_AMPERE; const int nwarps = NWARPS_Q2_K_AMPERE; - - mul_mat_q, - load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - #elif __CUDA_ARCH__ >= MIN_CC_DP4A const int mmq_x = MMQ_X_Q2_K_PASCAL; const int mmq_y = MMQ_Y_Q2_K_PASCAL; const int nwarps = NWARPS_Q2_K_PASCAL; - - mul_mat_q, - load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); #else - (void) vec_dot_q2_K_q8_1_mul_mat; + const int mmq_x = -1; + const int mmq_y = -1; + const int nwarps = -1; assert(false); #endif // __CUDA_ARCH__ >= CC_TURING + + mul_mat_q, + load_tiles_q2_K, VDR_Q2_K_Q8_1_MMQ, 
vec_dot_q2_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y); } #define MMQ_X_Q3_K_AMPERE 128 @@ -3665,29 +3662,27 @@ template static __global__ void #endif // __CUDA_ARCH__ < CC_TURING mul_mat_q3_K( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const int row_stride_x, const int channel_stride_x, const int channel_stride_y) { #if __CUDA_ARCH__ >= CC_TURING const int mmq_x = MMQ_X_Q3_K_AMPERE; const int mmq_y = MMQ_Y_Q3_K_AMPERE; const int nwarps = NWARPS_Q3_K_AMPERE; - - mul_mat_q, - load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - #elif __CUDA_ARCH__ >= MIN_CC_DP4A const int mmq_x = MMQ_X_Q3_K_PASCAL; const int mmq_y = MMQ_Y_Q3_K_PASCAL; const int nwarps = NWARPS_Q3_K_PASCAL; - - mul_mat_q, - load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); #else - (void) vec_dot_q3_K_q8_1_mul_mat; + const int mmq_x = -1; + const int mmq_y = -1; + const int nwarps = -1; assert(false); #endif // __CUDA_ARCH__ >= CC_TURING + + mul_mat_q, + load_tiles_q3_K, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y); } #define MMQ_X_Q4_K_AMPERE 64 @@ -3703,29 +3698,27 @@ template static __global__ void #endif // __CUDA_ARCH__ < CC_TURING mul_mat_q4_K( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const int row_stride_x, const int channel_stride_x, const int channel_stride_y) { #if __CUDA_ARCH__ >= CC_TURING const int mmq_x = MMQ_X_Q4_K_AMPERE; const int mmq_y = MMQ_Y_Q4_K_AMPERE; const int nwarps = NWARPS_Q4_K_AMPERE; - - mul_mat_q, - load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - #elif __CUDA_ARCH__ >= MIN_CC_DP4A const int mmq_x = MMQ_X_Q4_K_PASCAL; const int mmq_y = MMQ_Y_Q4_K_PASCAL; const int nwarps = NWARPS_Q4_K_PASCAL; - - mul_mat_q, - load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); #else - (void) vec_dot_q4_K_q8_1_mul_mat; + const int mmq_x = -1; + const int mmq_y = -1; + const int nwarps = -1; assert(false); #endif // __CUDA_ARCH__ >= CC_TURING + + mul_mat_q, + load_tiles_q4_K, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y); } #define MMQ_X_Q5_K_AMPERE 64 @@ -3737,29 +3730,27 @@ template static __global__ void template static __global__ void mul_mat_q5_K( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const int row_stride_x, const int channel_stride_x, const int channel_stride_y) { #if 
__CUDA_ARCH__ >= CC_TURING const int mmq_x = MMQ_X_Q5_K_AMPERE; const int mmq_y = MMQ_Y_Q5_K_AMPERE; const int nwarps = NWARPS_Q5_K_AMPERE; - - mul_mat_q, - load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - #elif __CUDA_ARCH__ >= MIN_CC_DP4A const int mmq_x = MMQ_X_Q5_K_PASCAL; const int mmq_y = MMQ_Y_Q5_K_PASCAL; const int nwarps = NWARPS_Q5_K_PASCAL; - - mul_mat_q, - load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); #else - (void) vec_dot_q5_K_q8_1_mul_mat; + const int mmq_x = -1; + const int mmq_y = -1; + const int nwarps = -1; assert(false); #endif // __CUDA_ARCH__ >= CC_TURING + + mul_mat_q, + load_tiles_q5_K, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y); } #define MMQ_X_Q6_K_AMPERE 64 @@ -3775,50 +3766,53 @@ template static __global__ void #endif // __CUDA_ARCH__ < CC_TURING mul_mat_q6_K( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) { + const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, + const int row_stride_x, const int channel_stride_x, const int channel_stride_y) { #if __CUDA_ARCH__ >= CC_TURING const int mmq_x = MMQ_X_Q6_K_AMPERE; const int mmq_y = MMQ_Y_Q6_K_AMPERE; const int nwarps = NWARPS_Q6_K_AMPERE; - - mul_mat_q, - load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); - #elif __CUDA_ARCH__ >= MIN_CC_DP4A const int mmq_x = MMQ_X_Q6_K_PASCAL; const int mmq_y = MMQ_Y_Q6_K_PASCAL; const int nwarps = NWARPS_Q6_K_PASCAL; - - mul_mat_q, - load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); #else - (void) vec_dot_q6_K_q8_1_mul_mat; + const int mmq_x = -1; + const int mmq_y = -1; + const int nwarps = -1; assert(false); #endif // __CUDA_ARCH__ >= CC_TURING + + mul_mat_q, + load_tiles_q6_K, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat> + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, row_stride_x, channel_stride_x, channel_stride_y); } template -static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) { +static __global__ void mul_mat_vec_q( + const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const int ncols, const int nrows, const int row_stride, const int channel_stride_x, const int channel_stride_y) { + const int row = blockIdx.y*blockDim.y + threadIdx.y; if (row >= nrows) { return; } + const int channel = blockIdx.z*blockDim.z + threadIdx.z; + const int blocks_per_row = ncols / qk; const int blocks_per_warp = vdr * WARP_SIZE / qi; // partial sum for each thread float tmp = 0.0f; - const block_q_t * x = (const block_q_t *) vx; - const block_q8_1 * y = (const block_q8_1 *) vy; + const block_q_t * x = ((const block_q_t *) vx) + channel*channel_stride_x; + const block_q8_1 * y = ((const block_q8_1 *) vy) + channel*channel_stride_y; for (int i = 0; i < blocks_per_row; i += blocks_per_warp) { - const int ibx = row*blocks_per_row + i + threadIdx.x / (qi/vdr); // x block index + const int ibx = row*row_stride + i + threadIdx.x / (qi/vdr); // x 
block index const int iby = (i + threadIdx.x / (qi/vdr)) * (qk/QK8_1); // y block index that aligns with ibx @@ -3834,7 +3828,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * } if (threadIdx.x == 0) { - dst[row] = tmp; + dst[channel*nrows + row] = tmp; } } @@ -4041,6 +4035,60 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, cpy_1(cx + x_offset, cdst + dst_offset); } +template +static __global__ void cpy_f32_q8_0( + const char * cx, char * cdst, const int i_blck_0, const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, const int nb11, const int nb12) { + + const int i0 = blockDim.x*blockIdx.x + threadIdx.x; + const int i1 = blockDim.y*blockIdx.y + threadIdx.y; + const int i2 = blockDim.z*blockIdx.z + threadIdx.z; + + float * x = (float *) (cx + (i0 - i_blck_0)*nb00 + i1*nb01 + i2*nb02); + block_q8_0 * dst = (block_q8_0 *) (cdst + i1*nb11 + i2*nb12); + dst += i0 / QK8_0; + const int iqs = i0 % QK8_0; + + float zero = 0.0f; + void * src = x; + + if (first_incomplete && i0 < i_blck_0) { + src = &dst[1 + iqs/8].qs[sizeof(float) * (iqs % 8)]; + } + if (last_incomplete && i0 >= (i_blck_0 + ne00)) { + src = &zero; + } + + float val; + if (first_incomplete) { + memcpy(&val, src, sizeof(float)); + } else { + val = *((float *) src); + } + + if (last_incomplete && i0 / QK8_0 == (i_blck_0 + ne00) / QK8_0) { + memcpy(&dst[1 + iqs/8].qs[sizeof(float) * (iqs % 8)], src, sizeof(float)); + } + + float amax = fabsf(val); + +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32)); + } + + const float d = amax / 127; + const int8_t q = amax == 0.0f ? 0 : roundf(val / d); + + dst->qs[iqs] = q; + + if (threadIdx.x != 0) { + return; + } + + dst->d = d; +} + // rope == RoPE == rotary positional embedding static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0, const float p_delta, const int p_delta_rows, const float theta_scale) { @@ -4256,11 +4304,14 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con } } -static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) { +static void quantize_row_q8_1_cuda( + const float * x, void * vy, const int kx, const int ky, const int kx_padded, const int nchannels, + const int row_stride, const int channel_stride, cudaStream_t stream) { + const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; - const dim3 num_blocks(block_num_x, ky, 1); + const dim3 num_blocks(block_num_x, ky, nchannels); const dim3 block_size(CUDA_DEQUANTIZE_BLOCK_SIZE, 1, 1); - quantize_q8_1<<>>(x, vy, kx, kx_padded); + quantize_q8_1<<>>(x, vy, kx, kx_padded, ky, row_stride, channel_stride); } static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { @@ -4416,94 +4467,124 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f dequantize_mul_mat_vec_q6_k<<>>(vx, y, dst, ncols, nrows); } -static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_q4_0_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels, + const int row_stride, const int channel_stride, const int channel_stride_y, cudaStream_t stream) { + 
GGML_ASSERT(ncols % QK4_0 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(1, block_num_y, nchannels); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, row_stride, channel_stride, channel_stride_y); } -static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_q4_1_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels, + const int row_stride, const int channel_stride, const int channel_stride_y, cudaStream_t stream) { + GGML_ASSERT(ncols % QK4_1 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(1, block_num_y, nchannels); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, row_stride, channel_stride, channel_stride_y); } -static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_q5_0_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels, + const int row_stride, const int channel_stride, const int channel_stride_y, cudaStream_t stream) { + GGML_ASSERT(ncols % QK5_0 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(1, block_num_y, nchannels); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, row_stride, channel_stride, channel_stride_y); } -static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_q5_1_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels, + const int row_stride, const int channel_stride, const int channel_stride_y, cudaStream_t stream) { + GGML_ASSERT(ncols % QK5_1 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(1, block_num_y, nchannels); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, row_stride, channel_stride, channel_stride_y); } -static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_q8_0_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels, + const int row_stride, const int channel_stride, const int channel_stride_y, cudaStream_t stream) { + GGML_ASSERT(ncols % QK8_0 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(1, block_num_y, nchannels); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, row_stride, channel_stride, channel_stride_y); } -static void mul_mat_vec_q2_K_q8_1_cuda(const 
void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_q2_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels, + const int row_stride, const int channel_stride, const int channel_stride_y, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(1, block_num_y, nchannels); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, row_stride, channel_stride, channel_stride_y); } -static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_q3_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels, + const int row_stride, const int channel_stride, const int channel_stride_y, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(1, block_num_y, nchannels); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, row_stride, channel_stride, channel_stride_y); } -static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_q4_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels, + const int row_stride, const int channel_stride, const int channel_stride_y, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(1, block_num_y, nchannels); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, row_stride, channel_stride, channel_stride_y); } -static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_q5_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels, + const int row_stride, const int channel_stride, const int channel_stride_y, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(1, block_num_y, nchannels); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, row_stride, channel_stride, channel_stride_y); } -static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void mul_mat_vec_q6_K_q8_1_cuda( + const void * vx, const void * vy, float * dst, const int ncols, const int nrows, const int nchannels, + const int row_stride, const int channel_stride, const int channel_stride_y, cudaStream_t stream) { + GGML_ASSERT(ncols % QK_K == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / 
GGML_CUDA_MMV_Y; - const dim3 block_nums(1, block_num_y, 1); + const dim3 block_nums(1, block_num_y, nchannels); const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); mul_mat_vec_q - <<>>(vx, vy, dst, ncols, nrows); + <<>>(vx, vy, dst, ncols, nrows, row_stride, channel_stride, channel_stride_y); } static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { @@ -4551,7 +4632,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { static void ggml_mul_mat_q4_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_y, const int nrows_y, const int nrows_dst, const int nchannels, + const int row_stride, const int channel_stride_x, const int channel_stride_y, cudaStream_t stream) { int id; CUDA_CHECK(cudaGetDevice(&id)); @@ -4572,23 +4654,26 @@ static void ggml_mul_mat_q4_0_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_nums(block_num_x, block_num_y, nchannels); const dim3 block_dims(WARP_SIZE, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; mul_mat_q4_0<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, + row_stride, channel_stride_x, channel_stride_y); } else { const bool need_check = true; mul_mat_q4_0<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, + row_stride, channel_stride_x, channel_stride_y); } } static void ggml_mul_mat_q4_1_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_y, const int nrows_y, const int nrows_dst, const int nchannels, + const int row_stride, const int channel_stride_x, const int channel_stride_y, cudaStream_t stream) { int id; CUDA_CHECK(cudaGetDevice(&id)); @@ -4609,23 +4694,26 @@ static void ggml_mul_mat_q4_1_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_nums(block_num_x, block_num_y, nchannels); const dim3 block_dims(WARP_SIZE, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; mul_mat_q4_1<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, + row_stride, channel_stride_x, channel_stride_y); } else { const bool need_check = true; mul_mat_q4_1<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, + row_stride, channel_stride_x, channel_stride_y); } } static void ggml_mul_mat_q5_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_y, const int nrows_y, const int nrows_dst, const int nchannels, + const int row_stride, const int channel_stride_x, const int channel_stride_y, cudaStream_t stream) { int id; CUDA_CHECK(cudaGetDevice(&id)); @@ -4646,23 +4734,26 @@ static void ggml_mul_mat_q5_0_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 
1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_nums(block_num_x, block_num_y, nchannels); const dim3 block_dims(WARP_SIZE, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; mul_mat_q5_0<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, + row_stride, channel_stride_x, channel_stride_y); } else { const bool need_check = true; mul_mat_q5_0<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, + row_stride, channel_stride_x, channel_stride_y); } } static void ggml_mul_mat_q5_1_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_y, const int nrows_y, const int nrows_dst, const int nchannels, + const int row_stride, const int channel_stride_x, const int channel_stride_y, cudaStream_t stream) { int id; CUDA_CHECK(cudaGetDevice(&id)); @@ -4683,23 +4774,26 @@ static void ggml_mul_mat_q5_1_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_nums(block_num_x, block_num_y, nchannels); const dim3 block_dims(WARP_SIZE, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; mul_mat_q5_1<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, + row_stride, channel_stride_x, channel_stride_y); } else { const bool need_check = true; mul_mat_q5_1<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, + row_stride, channel_stride_x, channel_stride_y); } } static void ggml_mul_mat_q8_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_y, const int nrows_y, const int nrows_dst, const int nchannels, + const int row_stride, const int channel_stride_x, const int channel_stride_y, cudaStream_t stream) { int id; CUDA_CHECK(cudaGetDevice(&id)); @@ -4720,23 +4814,26 @@ static void ggml_mul_mat_q8_0_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_nums(block_num_x, block_num_y, nchannels); const dim3 block_dims(WARP_SIZE, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; mul_mat_q8_0<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, + row_stride, channel_stride_x, channel_stride_y); } else { const bool need_check = true; mul_mat_q8_0<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, + row_stride, channel_stride_x, channel_stride_y); } } static void ggml_mul_mat_q2_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_y, const int nrows_y, const int nrows_dst, const int nchannels, + const int 
row_stride, const int channel_stride_x, const int channel_stride_y, cudaStream_t stream) { int id; CUDA_CHECK(cudaGetDevice(&id)); @@ -4757,23 +4854,26 @@ static void ggml_mul_mat_q2_K_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_nums(block_num_x, block_num_y, nchannels); const dim3 block_dims(WARP_SIZE, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; mul_mat_q2_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, + row_stride, channel_stride_x, channel_stride_y); } else { const bool need_check = true; mul_mat_q2_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, + row_stride, channel_stride_x, channel_stride_y); } } static void ggml_mul_mat_q3_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_y, const int nrows_y, const int nrows_dst, const int nchannels, + const int row_stride, const int channel_stride_x, const int channel_stride_y, cudaStream_t stream) { #if QK_K == 256 @@ -4796,24 +4896,27 @@ static void ggml_mul_mat_q3_K_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_nums(block_num_x, block_num_y, nchannels); const dim3 block_dims(WARP_SIZE, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; mul_mat_q3_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, + row_stride, channel_stride_x, channel_stride_y); } else { const bool need_check = true; mul_mat_q3_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, + row_stride, channel_stride_x, channel_stride_y); } #endif } static void ggml_mul_mat_q4_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_y, const int nrows_y, const int nrows_dst, const int nchannels, + const int row_stride, const int channel_stride_x, const int channel_stride_y, cudaStream_t stream) { int id; CUDA_CHECK(cudaGetDevice(&id)); @@ -4834,23 +4937,26 @@ static void ggml_mul_mat_q4_K_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_nums(block_num_x, block_num_y, nchannels); const dim3 block_dims(WARP_SIZE, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; mul_mat_q4_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, + row_stride, channel_stride_x, channel_stride_y); } else { const bool need_check = true; mul_mat_q4_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, + row_stride, channel_stride_x, channel_stride_y); } } static void ggml_mul_mat_q5_K_q8_1_cuda( const void * vx, const void * vy, float * 
dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_y, const int nrows_y, const int nrows_dst, const int nchannels, + const int row_stride, const int channel_stride_x, const int channel_stride_y, cudaStream_t stream) { int id; CUDA_CHECK(cudaGetDevice(&id)); @@ -4871,23 +4977,26 @@ static void ggml_mul_mat_q5_K_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_nums(block_num_x, block_num_y, nchannels); const dim3 block_dims(WARP_SIZE, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; mul_mat_q5_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, + row_stride, channel_stride_x, channel_stride_y); } else { const bool need_check = true; mul_mat_q5_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, + row_stride, channel_stride_x, channel_stride_y); } } static void ggml_mul_mat_q6_K_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, - const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { + const int ncols_y, const int nrows_y, const int nrows_dst, const int nchannels, + const int row_stride, const int channel_stride_x, const int channel_stride_y, cudaStream_t stream) { int id; CUDA_CHECK(cudaGetDevice(&id)); @@ -4908,17 +5017,19 @@ static void ggml_mul_mat_q6_K_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; - const dim3 block_nums(block_num_x, block_num_y, 1); + const dim3 block_nums(block_num_x, block_num_y, nchannels); const dim3 block_dims(WARP_SIZE, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; mul_mat_q6_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, + row_stride, channel_stride_x, channel_stride_y); } else { const bool need_check = true; mul_mat_q6_K<<>> - (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); + (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, + row_stride, channel_stride_x, channel_stride_y); } } @@ -4961,6 +5072,33 @@ static void ggml_cpy_f32_f16_cuda( (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12); } +static void ggml_cpy_f32_q8_0_cuda( + const char * cx, char * cdst, const int i_blck_0, const int ne00, const int ne01, const int ne02, + const int nb00, const int nb01, const int nb02, const int nb11, const int nb12, cudaStream_t stream) { + + const int num_blocks_x = (i_blck_0 + ne00 + WARP_SIZE - 1) / WARP_SIZE; + const dim3 block_nums(num_blocks_x, ne01, ne02); + const dim3 block_dims(WARP_SIZE, 1 , 1); + + const bool first_incomplete = i_blck_0 != 0; + const bool last_incomplete = (i_blck_0 + ne00) % QK8_0 != 0; + + if (first_incomplete && last_incomplete) { + GGML_ASSERT(i_blck_0 + ne00 < QK8_0); // otherwise there would be a race condition + cpy_f32_q8_0<<>> + (cx, cdst, i_blck_0, ne00, ne01, ne02, nb00, nb01, nb02, nb11, nb12); + } else if (first_incomplete && !last_incomplete) { + cpy_f32_q8_0<<>> + (cx, cdst, i_blck_0, ne00, ne01, ne02, nb00, nb01, nb02, nb11, nb12); + } else if (!first_incomplete && last_incomplete) { + cpy_f32_q8_0<<>> + (cx, cdst, 
i_blck_0, ne00, ne01, ne02, nb00, nb01, nb02, nb11, nb12); + } else if (!first_incomplete && !last_incomplete) { + cpy_f32_q8_0<<>> + (cx, cdst, i_blck_0, ne00, ne01, ne02, nb00, nb01, nb02, nb11, nb12); + } +} + static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; scale_f32<<>>(x, dst, scale, k); @@ -5397,6 +5535,7 @@ inline void ggml_cuda_op_rms_norm( (void) i1; } +template inline void ggml_cuda_op_mul_mat_q( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, @@ -5407,13 +5546,20 @@ inline void ggml_cuda_op_mul_mat_q( GGML_ASSERT(dst_ddf_i != nullptr); const int64_t ne00 = src0->ne[0]; + const int64_t ne02 = src0->ne[2]; const int64_t ne10 = src1->ne[0]; const int64_t ne11 = src1->ne[1]; - GGML_ASSERT(ne10 % QK8_1 == 0); const int64_t ne0 = dst->ne[0]; + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + const size_t nb01 = src0->nb[1]; + const size_t nb02 = src0->nb[2]; + + const size_t nb11 = src1->nb[1]; + const size_t nb12 = src1->nb[2]; + const int64_t i01_diff = i01_high - i01_low; int id; @@ -5423,42 +5569,72 @@ inline void ggml_cuda_op_mul_mat_q( // nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : i01_diff; + const int nchannels = buffers_contiguous ? 1 : ne02; + const int64_t padded_row_size = ne10 % MATRIX_ROW_PADDING == 0 ? ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING; size_t as; - void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*ne11*sizeof(block_q8_1)/QK8_1, &as); - quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, ne11, padded_row_size, cudaStream_main); + void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*ne11*nchannels*sizeof(block_q8_1)/QK8_1, &as); + const int64_t src1_row_stride = buffers_contiguous ? ne10 : nb11 / sizeof(float); + const int64_t src1_channel_stride = buffers_contiguous ? ne10*ne11 : nb12 / sizeof(float); + quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, ne11, padded_row_size, nchannels, + src1_row_stride, src1_channel_stride, cudaStream_main); + + const int row_stride = buffers_contiguous ? ne10 / ggml_blck_size(src0->type) : nb01 / ggml_type_size(src0->type); + const int channel_stride_x = buffers_contiguous ? 
ne10*ne11 / ggml_blck_size(src0->type) : nb02 / ggml_type_size(src0->type); + const int channel_stride_y = padded_row_size*ne11 / QK8_1; switch (src0->type) { case GGML_TYPE_Q4_0: - ggml_mul_mat_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + ggml_mul_mat_q4_0_q8_1_cuda( + src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, nchannels, + row_stride, channel_stride_x, channel_stride_y, cudaStream_main); break; case GGML_TYPE_Q4_1: - ggml_mul_mat_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + ggml_mul_mat_q4_1_q8_1_cuda( + src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, nchannels, + row_stride, channel_stride_x, channel_stride_y, cudaStream_main); break; case GGML_TYPE_Q5_0: - ggml_mul_mat_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + ggml_mul_mat_q5_0_q8_1_cuda( + src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, nchannels, + row_stride, channel_stride_x, channel_stride_y, cudaStream_main); break; case GGML_TYPE_Q5_1: - ggml_mul_mat_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + ggml_mul_mat_q5_1_q8_1_cuda( + src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, nchannels, + row_stride, channel_stride_x, channel_stride_y, cudaStream_main); break; case GGML_TYPE_Q8_0: - ggml_mul_mat_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + ggml_mul_mat_q8_0_q8_1_cuda( + src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, nchannels, + row_stride, channel_stride_x, channel_stride_y, cudaStream_main); + CUDA_CHECK(cudaDeviceSynchronize()); break; case GGML_TYPE_Q2_K: - ggml_mul_mat_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + ggml_mul_mat_q2_K_q8_1_cuda( + src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, nchannels, + row_stride, channel_stride_x, channel_stride_y, cudaStream_main); break; case GGML_TYPE_Q3_K: - ggml_mul_mat_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + ggml_mul_mat_q3_K_q8_1_cuda( + src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, nchannels, + row_stride, channel_stride_x, channel_stride_y, cudaStream_main); break; case GGML_TYPE_Q4_K: - ggml_mul_mat_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + ggml_mul_mat_q4_K_q8_1_cuda( + src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, nchannels, + row_stride, channel_stride_x, channel_stride_y, cudaStream_main); break; case GGML_TYPE_Q5_K: - ggml_mul_mat_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); + ggml_mul_mat_q5_K_q8_1_cuda( + src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, nchannels, + row_stride, channel_stride_x, channel_stride_y, cudaStream_main); break; case GGML_TYPE_Q6_K: - ggml_mul_mat_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, cudaStream_main); 
+ ggml_mul_mat_q6_K_q8_1_cuda( + src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, i01_diff, ne11, padded_row_size, nrows_dst, nchannels, + row_stride, channel_stride_x, channel_stride_y, cudaStream_main); break; default: GGML_ASSERT(false); @@ -5505,6 +5681,7 @@ static int64_t get_row_rounding(ggml_type type) { } } +template inline void ggml_cuda_op_mul_mat_vec( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, @@ -5515,6 +5692,16 @@ inline void ggml_cuda_op_mul_mat_vec( GGML_ASSERT(dst_ddf_i != nullptr); const int64_t ne00 = src0->ne[0]; + const int64_t ne02 = src0->ne[2]; + + const int64_t ne10 = src1->ne[0]; + + const int64_t nb01 = src0->nb[1]; + const int64_t nb02 = src0->nb[2]; + + const int64_t nb11 = src1->nb[1]; + const int64_t nb12 = src1->nb[2]; + const int64_t nrows = i01_high - i01_low; #ifdef GGML_CUDA_FORCE_DMMV @@ -5540,45 +5727,50 @@ inline void ggml_cuda_op_mul_mat_vec( #endif // QK_K == 256 const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= MIN_CC_DP4A && mul_mat_vec_q_implemented; -#endif +#endif // GGML_CUDA_FORCE_DMMV if (use_mul_mat_vec_q) { - const int64_t padded_row_size = ne00 % MATRIX_ROW_PADDING == 0 ? - ne00 : ne00 - ne00 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING; + const int64_t padded_row_size = ne10 % MATRIX_ROW_PADDING == 0 ? + ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING; size_t as; - void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*sizeof(block_q8_1)/QK8_1, &as); - quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, 1, padded_row_size, cudaStream_main); - + void * src1_q8_1 = ggml_cuda_pool_malloc(padded_row_size*ne02*sizeof(block_q8_1)/QK8_1, &as); + const int64_t row_stride_q = src1->backend == GGML_BACKEND_CPU ? ne10 : nb11 / sizeof(float); + const int64_t channel_stride_q = src1->backend == GGML_BACKEND_CPU ? 
ne10*1 : nb12 / sizeof(float); + quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne10, 1, padded_row_size, ne02, row_stride_q, channel_stride_q, cudaStream_main); + + const int row_stride_x = nb01 / ggml_type_size(src0->type); + const int channel_stride_x = nb02 / ggml_type_size(src0->type); + const int channel_stride_y = padded_row_size / QK8_1; switch (src0->type) { case GGML_TYPE_Q4_0: - mul_mat_vec_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); + mul_mat_vec_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main); break; case GGML_TYPE_Q4_1: - mul_mat_vec_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); + mul_mat_vec_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main); break; case GGML_TYPE_Q5_0: - mul_mat_vec_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); + mul_mat_vec_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main); break; case GGML_TYPE_Q5_1: - mul_mat_vec_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); + mul_mat_vec_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main); break; case GGML_TYPE_Q8_0: - mul_mat_vec_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); + mul_mat_vec_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main); break; case GGML_TYPE_Q2_K: - mul_mat_vec_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); + mul_mat_vec_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main); break; case GGML_TYPE_Q3_K: - mul_mat_vec_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); + mul_mat_vec_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main); break; case GGML_TYPE_Q4_K: - mul_mat_vec_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); + mul_mat_vec_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main); break; case GGML_TYPE_Q5_K: - mul_mat_vec_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); + mul_mat_vec_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main); break; case GGML_TYPE_Q6_K: - mul_mat_vec_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); + mul_mat_vec_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, ne02, row_stride_x, channel_stride_x, channel_stride_y, cudaStream_main); break; default: GGML_ASSERT(false); @@ -5587,6 +5779,8 @@ inline void ggml_cuda_op_mul_mat_vec( ggml_cuda_pool_free(src1_q8_1, as); } else { + GGML_ASSERT(buffers_contiguous || ne02 == 1); + // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics #ifdef GGML_CUDA_F16 size_t ash; @@ -5782,6 +5976,7 @@ inline void ggml_cuda_op_alibi( (void) src0_ddq_i; (void) src1_ddf_i; (void) i1; + (void) i02; } inline void ggml_cuda_op_diag_mask_inf( 
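The padded row size computed in both ggml_cuda_op_mul_mat_q and ggml_cuda_op_mul_mat_vec above is ordinary round-up-to-a-multiple arithmetic: src1 rows are padded so that their quantized q8_1 copy consists only of whole blocks, and channel_stride_y then simply counts q8_1 blocks per channel. A minimal standalone sketch, assuming MATRIX_ROW_PADDING and QK8_1 keep the values defined elsewhere in ggml-cuda.cu (512 and 32 at the time of writing) and using a hypothetical round_up() helper purely for illustration:

// Round n up to the next multiple of `padding`; same expression as the
// `ne10 % MATRIX_ROW_PADDING == 0 ? ne10 : ne10 - ne10 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING`
// ternaries used in the hunks above.
static int64_t round_up(const int64_t n, const int64_t padding) {
    return n % padding == 0 ? n : n - n % padding + padding;
}

// Example for a single src1 row (ne11 = 1) of length ne10 = 4097:
//   padded_row_size  = round_up(4097, 512)          = 4608
//   channel_stride_y = padded_row_size*ne11 / QK8_1 = 4608 / 32 = 144 q8_1 blocks
// The same rounding, but to the block size, reappears later as v_nelements_padded when
// llm_build_llama builds the V view over the quantized cache.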
@@ -5984,7 +6179,16 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm if (src0_is_f32) { src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]); } else { - src0_ddq[id] = (char *) ggml_cuda_pool_malloc(row_diff*ne00 * src0_ts/src0_bs, &src0_asq[id]); + const int64_t nelements = row_diff*ne00; + const int64_t nelements_padded = ne00 % MATRIX_ROW_PADDING == 0 ? + nelements : nelements - ne00 % MATRIX_ROW_PADDING + MATRIX_ROW_PADDING; + const size_t size_padded = nelements_padded * src0_ts/src0_bs; + + src0_ddq[id] = (char *) ggml_cuda_pool_malloc(size_padded, &src0_asq[id]); + + if (nelements_padded > nelements) { + CUDA_CHECK(cudaMemsetAsync(src0_ddq[id], 0, size_padded, cudaStream_main)); + } } } @@ -6232,12 +6436,36 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te return false; } +void ggml_cuda_mul_mat_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ + GGML_ASSERT(src0->backend == GGML_BACKEND_GPU); + GGML_ASSERT(src1->backend == GGML_BACKEND_GPU); + GGML_ASSERT(dst->backend == GGML_BACKEND_GPU); + GGML_ASSERT(ggml_is_quantized(src0->type)); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const int64_t ne01 = src0->ne[1]; + + CUDA_CHECK(cudaSetDevice(g_main_device)); + cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device]; + + struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + char * src0_ddq = (char *) src0_extra->data_device[g_main_device]; + + struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; + + struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; + + ggml_cuda_op_mul_mat_q(src0, src1, dst, src0_ddq, nullptr, src1_ddf, dst_ddf, 0, 0, ne01, 0, cudaStream_main); + CUDA_CHECK(cudaGetLastError()); +} + void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation - GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); const int64_t ne00 = src0->ne[0]; @@ -6250,7 +6478,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device]; struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; - void * src0_ddq = src0_extra->data_device[g_main_device]; + char * src0_ddq = (char *) src0_extra->data_device[g_main_device]; struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; @@ -6258,14 +6486,19 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; - ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, cudaStream_main); + if (src0->type == GGML_TYPE_F16) { + ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, 
cudaStream_main); + } else if (ggml_is_quantized(src0->type)) { + ggml_cuda_op_mul_mat_vec(src0, src1, dst, src0_ddq, nullptr, src1_ddf, dst_ddf, 0, 0, ne01, 0, cudaStream_main); + } else { + GGML_ASSERT(false); + } } void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)); GGML_ASSERT(!ggml_is_permuted(src0)); GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); - GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src1->type == GGML_TYPE_F32); const int64_t ne00 = src0->ne[0]; @@ -6281,7 +6514,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1 cudaStream_t cudaStream_main = g_cudaStreams_main[g_main_device]; struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; - void * src0_ddq = src0_extra->data_device[g_main_device]; + char * src0_ddq = (char *) src0_extra->data_device[g_main_device]; struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; @@ -6289,15 +6522,22 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1 struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; - const int row_stride_x = nb01 / sizeof(half); - const int channel_stride_x = nb02 / sizeof(half); + if (src0->type == GGML_TYPE_F16) { + const int row_stride_x = nb01 / sizeof(half); + const int channel_stride_x = nb02 / sizeof(half); - ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main); + ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, cudaStream_main); + } else if (ggml_is_quantized(src0->type)) { + ggml_cuda_op_mul_mat_vec(src0, src1, dst, src0_ddq, nullptr, src1_ddf, dst_ddf, 0, 0, ne01, 0, cudaStream_main); + } else { + GGML_ASSERT(false); + } } void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && + const bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU; + const bool src0_is_quantized = ggml_is_quantized(src0->type); if (all_on_device && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { ggml_cuda_mul_mat_vec_p021(src0, src1, dst); @@ -6307,7 +6547,7 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false); } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) { - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false); + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false); } else { int min_compute_capability = INT_MAX; for (int id = 0; id < g_device_count; ++id) { @@ -6318,7 +6558,11 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ } if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) { - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false); + if (all_on_device && src0->backend != 
GGML_BACKEND_GPU_SPLIT) { + ggml_cuda_mul_mat_nc(src0, src1, dst); + } else { + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false); + } } else { ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false); } @@ -6345,6 +6589,7 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; GGML_ASSERT(src0->ne[3] == 1); const int64_t nb00 = src0->nb[0]; @@ -6374,6 +6619,21 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, cudaStream_main); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { + GGML_ASSERT(nb10 == sizeof(block_q8_0)); + + const size_t * op_params = (const size_t *) src1->op_params; + const size_t i_blck_0 = op_params[1]; + + if (ggml_is_contiguous(src1)) { + ggml_cpy_f32_q8_0_cuda( + src0_ddc, src1_ddc, i_blck_0, ne00, ne01, ne02, nb00, nb01, nb02, + ne00*sizeof(block_q8_0)/QK8_0, ne00*ne01*sizeof(block_q8_0)/QK8_0, cudaStream_main); + } else { + ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, i_blck_0, ne00, ne01, ne02, + nb00, nb01, nb02, nb11, nb12, cudaStream_main); + } + } else { GGML_ASSERT(false); } diff --git a/ggml.c b/ggml.c index 38b1155c13bc24..a0ac10c5941ec0 100644 --- a/ggml.c +++ b/ggml.c @@ -1099,11 +1099,19 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { assert(QK8_0 == 32); - assert(k % QK8_0 == 0); const int nb = k / QK8_0; block_q8_0 * restrict y = vy; + if (k % QK8_0 != 0) { + float x_end[QK8_0] = {0}; + memcpy(x_end, x + nb*QK8_0, sizeof(float) * (k % QK8_0)); + + block_q8_0 * y_end = y + nb; + + quantize_row_q8_0(x_end, y_end, QK8_0); + } + #if defined(__ARM_NEON) for (int i = 0; i < nb; i++) { float32x4_t srcv [8]; @@ -4355,8 +4363,13 @@ static inline bool ggml_is_matrix(const struct ggml_tensor * tensor) { static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); - return (t0->ne[0] == t1->ne[0]) && - (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable + const int64_t blck_size = ggml_blck_size(t0->type); + + const int64_t nblcks00_padded = (t0->ne[0] + blck_size - 1) / blck_size; + const int64_t nblcks10_padded = (t1->ne[0] + blck_size - 1) / blck_size; + + return (nblcks00_padded == nblcks10_padded) && // ensure same number of blocks after padding + (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable (t1->ne[3]%t0->ne[3] == 0); } @@ -6510,7 +6523,8 @@ static struct ggml_tensor * ggml_view_impl( struct ggml_tensor * a, int n_dims, const int64_t * ne, - size_t offset) { + size_t offset, + size_t i_blck) { bool is_node = false; @@ -6521,7 +6535,8 @@ static struct ggml_tensor * ggml_view_impl( struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, n_dims, ne, a, offset); ggml_format_name(result, "%s (view)", a->name); - ggml_set_op_params(result, &offset, sizeof(offset)); + size_t params[2] = {offset, i_blck}; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_VIEW; result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; @@ -6538,7 +6553,7 @@ struct ggml_tensor * ggml_view_1d( int64_t ne0, size_t offset) { - struct ggml_tensor * result = ggml_view_impl(ctx, a, 1, &ne0, offset); + struct ggml_tensor * result = ggml_view_impl(ctx, a, 1, &ne0, offset, 0); return result; } @@ -6555,7 +6570,7 @@ struct ggml_tensor * ggml_view_2d( const int64_t ne[2] = { ne0, ne1 }; - struct ggml_tensor * result = ggml_view_impl(ctx, a, 2, ne, offset); + struct ggml_tensor * result = ggml_view_impl(ctx, a, 2, ne, offset, 0); result->nb[1] = nb1; result->nb[2] = result->nb[1]*ne1; @@ -6578,7 +6593,7 @@ struct ggml_tensor * ggml_view_3d( const int64_t ne[3] = { ne0, ne1, ne2 }; - struct ggml_tensor * result = ggml_view_impl(ctx, a, 3, ne, offset); + struct ggml_tensor * result = ggml_view_impl(ctx, a, 3, ne, offset, 0); result->nb[1] = nb1; result->nb[2] = nb2; @@ -6603,7 +6618,7 @@ struct ggml_tensor * ggml_view_4d( const int64_t ne[4] = { ne0, ne1, ne2, ne3 }; - struct ggml_tensor * result = ggml_view_impl(ctx, a, 4, ne, offset); + struct ggml_tensor * result = ggml_view_impl(ctx, a, 4, ne, offset, 0); result->nb[1] = nb1; result->nb[2] = nb2; @@ -6612,6 +6627,42 @@ struct ggml_tensor * ggml_view_4d( return result; } +// ggml_view_blck_1d + +struct ggml_tensor * ggml_view_blck_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + size_t offset, + size_t i_blck) { + + struct ggml_tensor * result = ggml_view_impl(ctx, a, 1, &ne0, offset, i_blck); + + return result; +} + +// ggml_view_blck_2d + +struct ggml_tensor * ggml_view_blck_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + size_t nb1, + size_t offset, + size_t i_blck) { + + const int64_t ne[2] = { ne0, ne1 }; + + struct ggml_tensor * result = ggml_view_impl(ctx, a, 2, ne, offset, i_blck); + + result->nb[1] = nb1; + result->nb[2] = result->nb[1]*ne1; + result->nb[3] = result->nb[2]; + + return result; +} + // ggml_permute struct ggml_tensor * ggml_permute( @@ -8685,6 +8736,47 @@ static void ggml_compute_forward_dup_f32( } } } + } else if (type_traits[dst->type].from_float) { + GGML_ASSERT(ne00 == ne0); + GGML_ASSERT(ne01 == ne1); + GGML_ASSERT(ne02 == ne2); + GGML_ASSERT(ne03 == ne3); + + size_t blck_index_0 = 0; + if (dst->src[1]->op == GGML_OP_VIEW) { + const size_t * op_params = (const size_t *) dst->src[1]->op_params; + blck_index_0 = op_params[1]; + } + + ggml_from_float_t const quantize_row_q = type_traits[dst->type].from_float; + + for (int i03 = 0; i03 < ne03; i03++) { + for (int i02 = 0; i02 < ne02; i02++) { + for (int i01 = ir0; i01 < ir1; i01++) { + const char * src0_row_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + char * dst_row_ptr = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3; + size_t blck_index = blck_index_0; + + for (int i00 = 0; i00 < ne00; ++i00) { + char * dst_ptr = dst_row_ptr + + ggml_element_size(dst) * ((i00 + blck_index_0) / ggml_blck_size(dst->type)); + float * dst_tmp_ptr = (float *) (dst_ptr + ggml_element_size(dst)); + + if (blck_index == 0) { + memset(dst_tmp_ptr, 0, ggml_blck_size(dst->type)*sizeof(float)); + } + + dst_tmp_ptr[blck_index] = *((const float *) (src0_row_ptr + i00*nb00)); + + blck_index = (blck_index + 1) % ggml_blck_size(dst->type); + + if (blck_index == 0 || i00 == (ne00 - 1)) { + quantize_row_q(dst_tmp_ptr, dst_ptr, ggml_blck_size(dst->type)); + } + } + } + } + } } else { GGML_ASSERT(false); // TODO: implement } @@ -11207,6 +11299,9 @@ static void ggml_compute_forward_mul_mat( enum ggml_type const 
vec_dot_type = type_traits[type].vec_dot_type; ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; + GGML_ASSERT(vec_dot != NULL); + GGML_ASSERT(from_float_to_vec_dot != NULL); + GGML_ASSERT(ne0 == ne01); GGML_ASSERT(ne1 == ne11); GGML_ASSERT(ne2 == ne12); @@ -11299,7 +11394,8 @@ static void ggml_compute_forward_mul_mat( if (params->type == GGML_TASK_INIT) { if (src1->type != vec_dot_type) { char * wdata = params->wdata; - const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type); + const size_t row_size = ggml_type_size(vec_dot_type)*(ne10 + ggml_blck_size(vec_dot_type) - 1) + / ggml_blck_size(vec_dot_type); for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { @@ -11319,7 +11415,8 @@ static void ggml_compute_forward_mul_mat( } const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; - const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type); + const size_t row_size = ggml_type_size(vec_dot_type)*(ne10 + ggml_blck_size(vec_dot_type) - 1) + / ggml_blck_size(vec_dot_type); const int64_t nr0 = ne01; // src0 rows const int64_t nr1 = ne11*ne12*ne13; // src1 rows diff --git a/ggml.h b/ggml.h index c936823d661404..f731915b0e3f38 100644 --- a/ggml.h +++ b/ggml.h @@ -1137,6 +1137,22 @@ extern "C" { size_t nb3, size_t offset); + GGML_API struct ggml_tensor * ggml_view_blck_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + size_t offset, + size_t i_blck); + + GGML_API struct ggml_tensor * ggml_view_blck_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + size_t nb1, // row stride in bytes + size_t offset, + size_t i_blck); + GGML_API struct ggml_tensor * ggml_permute( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/llama.cpp b/llama.cpp index 3413288fcb4a82..2cec8b66a03568 100644 --- a/llama.cpp +++ b/llama.cpp @@ -920,12 +920,29 @@ struct llama_hparams { return n_embd/n_gqa(); } - size_t kv_size() const { - size_t result = 2ull; + size_t kv_size(ggml_type type) const { + return kv_size_k(type) + kv_size_v(type); + } + + size_t kv_size_k(ggml_type type) const { + size_t result = 1ull; result *= (size_t) n_embd_gqa(); result *= (size_t) n_ctx; result *= (size_t) n_layer; - result *= sizeof(ggml_fp16_t); + result *= ggml_type_size(type); + result /= ggml_blck_size(type); + return result; + } + + size_t kv_size_v(ggml_type type) const { + const size_t row_padding = type == GGML_TYPE_Q8_0 ? 
128 : 0;
+
+        size_t result = 1ull;
+        result *= (size_t) n_embd_gqa();
+        result *= (size_t) n_ctx + row_padding;
+        result *= (size_t) n_layer;
+        result *= ggml_type_size(type);
+        result /= ggml_blck_size(type);
         return result;
     }
 };
@@ -1150,10 +1167,26 @@ static bool llama_kv_cache_init(
     const int n_embd  = hparams.n_embd_gqa();
     const int n_layer = hparams.n_layer;

+    if (n_ctx % ggml_blck_size(wtype) != 0) {
+        LLAMA_LOG_ERROR("error: for KV type %s n_ctx must be a multiple of %d but received n_ctx=%d\n",
+                        ggml_type_name(wtype), ggml_blck_size(wtype), n_ctx);
+        return false;
+    }
+
+    if (n_embd % ggml_blck_size(wtype) != 0) {
+        LLAMA_LOG_ERROR("error: for KV type %s n_embd must be a multiple of %d but received n_embd=%d\n",
+                        ggml_type_name(wtype), ggml_blck_size(wtype), n_embd);
+        return false;
+    }
+
     const int64_t n_mem      = n_layer*n_ctx;
     const int64_t n_elements = n_embd*n_mem;

-    cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
+    // if the KV cache is quantized we need a little extra space for each row to store the
+    // unquantized values between evals (this avoids precision loss when rebuilding the block)
+    const int64_t v_quant_buffer = wtype == GGML_TYPE_Q8_0 ? 128*n_layer*n_embd : 0;
+
+    cache.buf.resize((2u*n_elements + v_quant_buffer)*ggml_type_size(wtype)/ggml_blck_size(wtype) + 2u*MB);
     cache.n = 0;

     struct ggml_init_params params;
@@ -1169,7 +1202,7 @@ static bool llama_kv_cache_init(
     }

     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
-    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
+    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements + v_quant_buffer);
     ggml_set_name(cache.k, "cache_k");
     ggml_set_name(cache.v, "cache_v");

@@ -2075,15 +2108,13 @@ static void llm_load_tensors(

     // print memory requirements
     {
-        const size_t scale = memory_type == GGML_TYPE_F32 ?
2 : 1; - // this is the total memory required to run the inference size_t mem_required = ctx_size + mmapped_size - vram_weights; // weights in VRAM not in memory // this is the memory required by one llama_state - const size_t mem_required_state = scale*hparams.kv_size(); + const size_t mem_required_state = hparams.kv_size(memory_type); LLAMA_LOG_INFO("%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__, mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0); @@ -2107,7 +2138,7 @@ static void llm_load_tensors( LLAMA_LOG_INFO("%s: cannot offload v cache to GPU due to low VRAM option\n", __func__); } else { LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__); - vram_kv_cache += hparams.kv_size() / 2; + vram_kv_cache += hparams.kv_size_v(memory_type); } } if (n_gpu_layers > (int) hparams.n_layer + 2) { @@ -2115,7 +2146,7 @@ static void llm_load_tensors( LLAMA_LOG_WARN("%s: cannot offload k cache to GPU due to low VRAM option\n", __func__); } else { LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__); - vram_kv_cache += hparams.kv_size() / 2; + vram_kv_cache += hparams.kv_size_k(memory_type); } } #elif defined(GGML_USE_CLBLAST) @@ -2367,13 +2398,17 @@ static struct ggml_cgraph * llm_build_llama( offload_func_v(Vcur); ggml_set_name(Vcur, "Vcur"); - struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d( + ctx0, kv_self.k, N*n_embd_gqa, + (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past)/ggml_blck_size(kv_self.k->type)); offload_func_kq(k); ggml_set_name(k, "k"); - struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa, - ( n_ctx)*ggml_element_size(kv_self.v), - (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v)); + const int64_t v_row_size = kv_self.v->type == GGML_TYPE_Q8_0 ? n_ctx + 128 : n_ctx; + struct ggml_tensor * v = ggml_view_blck_2d(ctx0, kv_self.v, N, n_embd_gqa, + ( v_row_size)*ggml_element_size(kv_self.v)/ggml_blck_size(kv_self.v->type), + (il*v_row_size)*ggml_element_size(kv_self.v)*n_embd_gqa/ggml_blck_size(kv_self.v->type) + ggml_element_size(kv_self.v)*(n_past/ggml_blck_size(kv_self.v->type)), + n_past % ggml_blck_size(kv_self.v->type)); offload_func_v(v); ggml_set_name(v, "v"); @@ -2389,9 +2424,9 @@ static struct ggml_cgraph * llm_build_llama( struct ggml_tensor * K = ggml_view_3d(ctx0, kv_self.k, n_embd_head, n_past + N, n_head_kv, - ggml_element_size(kv_self.k)*n_embd_gqa, - ggml_element_size(kv_self.k)*n_embd_head, - ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + ggml_element_size(kv_self.k)*n_embd_gqa/ggml_blck_size(kv_self.k->type), + ggml_element_size(kv_self.k)*n_embd_head/ggml_blck_size(kv_self.k->type), + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il/ggml_blck_size(kv_self.k->type)); offload_func_kq(K); ggml_set_name(K, "K"); @@ -2416,13 +2451,17 @@ static struct ggml_cgraph * llm_build_llama( offload_func_v(KQ_soft_max); ggml_set_name(KQ_soft_max, "KQ_soft_max"); + // split cached V into n_head heads + int64_t v_nelements_padded = n_past + N + ggml_blck_size(kv_self.v->type) - 1; + v_nelements_padded -= v_nelements_padded % ggml_blck_size(kv_self.v->type); + const int64_t v_row_size = kv_self.v->type == GGML_TYPE_Q8_0 ? 
n_ctx + 128 : n_ctx; struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v, - n_past + N, n_embd_head, n_head_kv, - ggml_element_size(kv_self.v)*n_ctx, - ggml_element_size(kv_self.v)*n_ctx*n_embd_head, - ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); + v_nelements_padded, n_embd_head, n_head_kv, + ggml_element_size(kv_self.v)*v_row_size/ggml_blck_size(kv_self.v->type), + ggml_element_size(kv_self.v)*v_row_size*n_embd_head/ggml_blck_size(kv_self.v->type), + ggml_element_size(kv_self.v)*v_row_size*n_embd_gqa*il/ggml_blck_size(kv_self.v->type)); offload_func_v(V); ggml_set_name(V, "V"); @@ -5371,9 +5410,9 @@ struct llama_context_params llama_context_default_params() { /*.rope_freq_scale =*/ 1.0f, /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, + /*.kv_type =*/ GGML_TYPE_Q8_0, /*.low_vram =*/ false, /*.mul_mat_q =*/ true, - /*.f16_kv =*/ true, /*.logits_all =*/ false, /*.vocab_only =*/ false, /*.use_mmap =*/ true, @@ -5448,8 +5487,6 @@ struct llama_model * llama_load_model_from_file( llama_model * model = new llama_model; - ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - unsigned cur_percentage = 0; if (params.progress_callback == NULL) { params.progress_callback_user_data = &cur_percentage; @@ -5468,7 +5505,7 @@ struct llama_model * llama_load_model_from_file( if (!llama_model_load(path_model, *model, params.n_ctx, params.n_batch, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.mul_mat_q, params.rope_freq_base, params.rope_freq_scale, - params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only, + params.low_vram, params.kv_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback, params.progress_callback_user_data)) { LLAMA_LOG_ERROR("%s: failed to load model\n", __func__); delete model; @@ -5499,11 +5536,9 @@ struct llama_context * llama_new_context_with_model( ctx->rng = std::mt19937(params.seed); ctx->logits_all = params.logits_all; - ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - // reserve memory for context buffers if (!params.vocab_only) { - if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, ctx->model.hparams.n_ctx, params.n_gpu_layers)) { + if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, params.kv_type, ctx->model.hparams.n_ctx, params.n_gpu_layers)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; diff --git a/llama.h b/llama.h index 5b95aaa8776dd8..f724499ac692bf 100644 --- a/llama.h +++ b/llama.h @@ -140,10 +140,11 @@ extern "C" { // context pointer passed to the progress callback void * progress_callback_user_data; + enum ggml_type kv_type; // the type to use for the KV cache + // Keep the booleans together to avoid misalignment during copy-by-value. 
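As a rough sanity check of the memory savings implied by the kv_size_k/kv_size_v helpers above (a sketch assuming 7B-style hyperparameters, not figures taken from this patch):

// Assumptions: n_embd_gqa = 4096, n_layer = 32, n_ctx = 4096;
// GGML_TYPE_Q8_0 stores 32 values in 34 bytes, GGML_TYPE_F16 stores each value in 2 bytes.
//   K, q8_0:    4096 * 4096         * 32 * 34/32 = 544 MiB
//   V, q8_0:    4096 * (4096 + 128) * 32 * 34/32 = 561 MiB   (128 extra padding values per V row)
//   K+V, f16:   2 * 4096 * 4096 * 32 * 2         = 2048 MiB
// i.e. the q8_0 cache takes a bit more than half the memory of the previous f16 default.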
bool low_vram;   // if true, reduce VRAM usage at the cost of performance
     bool mul_mat_q;  // if true, use experimental mul_mat_q kernels
-    bool f16_kv;     // use fp16 for KV cache
     bool logits_all; // the llama_eval() call computes all logits, not just the last one
     bool vocab_only; // only load the vocabulary, no weights
     bool use_mmap;   // use mmap if possible
diff --git a/run_with_preset.py b/run_with_preset.py
index 8f90f52a9586e9..df416828ec1221 100755
--- a/run_with_preset.py
+++ b/run_with_preset.py
@@ -11,8 +11,8 @@
     "batch-size", "cfg-negative-prompt", "cfg-scale", "chunks", "color", "ctx-size", "escape",
     "export", "file", "frequency-penalty", "grammar", "grammar-file", "hellaswag",
     "hellaswag-tasks", "ignore-eos", "in-prefix", "in-prefix-bos", "in-suffix", "instruct",
-    "interactive", "interactive-first", "keep", "logdir", "logit-bias", "lora", "lora-base",
-    "low-vram", "main-gpu", "memory-f32", "mirostat", "mirostat-ent", "mirostat-lr", "mlock",
+    "interactive", "interactive-first", "keep", "kv-type", "logdir", "logit-bias", "lora",
+    "lora-base", "low-vram", "main-gpu", "mirostat", "mirostat-ent", "mirostat-lr", "mlock",
     "model", "mtest", "multiline-input", "n-gpu-layers", "n-predict", "no-mmap", "no-mul-mat-q",
     "np-penalize-nl", "numa", "ppl-output-type", "ppl-stride", "presence-penalty", "prompt",
     "prompt-cache", "prompt-cache-all", "prompt-cache-ro", "random-prompt", "repeat-last-n",
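Finally, a minimal sketch of how a caller would select the KV cache type through the new llama_context_params field; the model path is a placeholder, and backend initialization plus error handling are omitted for brevity:

#include "llama.h"

int main(void) {
    struct llama_context_params cparams = llama_context_default_params(); // kv_type defaults to GGML_TYPE_Q8_0

    cparams.n_ctx   = 4096;
    cparams.kv_type = GGML_TYPE_F16; // opt back into the old f16 cache, e.g. to compare perplexity

    struct llama_model   * model = llama_load_model_from_file("models/7B/ggml-model-q4_0.gguf", cparams);
    struct llama_context * ctx   = llama_new_context_with_model(model, cparams);

    // ... evaluate tokens as usual ...

    llama_free(ctx);
    llama_free_model(model);
    return 0;
}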