diff --git a/examples/talk-llama/llama-adapter.cpp b/examples/talk-llama/llama-adapter.cpp index 9fd7edea332..8a080046313 100644 --- a/examples/talk-llama/llama-adapter.cpp +++ b/examples/talk-llama/llama-adapter.cpp @@ -1,5 +1,7 @@ #include "llama-adapter.h" +#include "llama-impl.h" +#include "llama-mmap.h" #include "llama-model.h" #include @@ -9,7 +11,7 @@ // vec -struct ggml_tensor * llama_control_vector::tensor_for(int il) const { +struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const { if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) { return nullptr; } @@ -17,7 +19,7 @@ struct ggml_tensor * llama_control_vector::tensor_for(int il) const { return tensors[il]; } -struct ggml_tensor * llama_control_vector::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const { +struct ggml_tensor * llama_adapter_cvec::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const { ggml_tensor * layer_dir = tensor_for(il); if (layer_dir != nullptr) { cur = ggml_add(ctx, cur, layer_dir); @@ -26,12 +28,12 @@ struct ggml_tensor * llama_control_vector::apply_to(struct ggml_context * ctx, s return cur; } -static bool llama_control_vector_init(struct llama_control_vector & cvec, const llama_model & model) { +bool llama_adapter_cvec::init(const llama_model & model) { const auto & hparams = model.hparams; - GGML_ASSERT(cvec.tensors.empty()); - GGML_ASSERT(cvec.ctxs.empty()); - GGML_ASSERT(cvec.bufs.empty()); + GGML_ASSERT(tensors.empty()); + GGML_ASSERT(ctxs.empty()); + GGML_ASSERT(bufs.empty()); // create a context for each buffer type std::map ctx_map; @@ -50,7 +52,7 @@ static bool llama_control_vector_init(struct llama_control_vector & cvec, const } ctx_map[buft] = ctx; - cvec.ctxs.emplace_back(ctx); + ctxs.emplace_back(ctx); return ctx; } @@ -59,21 +61,21 @@ static bool llama_control_vector_init(struct llama_control_vector & cvec, const }; // make tensors - cvec.tensors.reserve(hparams.n_layer); - cvec.tensors.push_back(nullptr); // there's never a tensor for layer 0 + tensors.reserve(hparams.n_layer); + tensors.push_back(nullptr); // there's never a tensor for layer 0 for (size_t il = 1; il < hparams.n_layer; il++) { - ggml_backend_buffer_type_t buft = llama_model_select_buft(model, il); + ggml_backend_buffer_type_t buft = model.select_buft(il); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__); return false; } ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); - cvec.tensors.push_back(tensor); + tensors.push_back(tensor); } // allocate tensors / buffers and zero - cvec.bufs.reserve(ctx_map.size()); + bufs.reserve(ctx_map.size()); for (auto it : ctx_map) { ggml_backend_buffer_type_t buft = it.first; ggml_context * ctx = it.second; @@ -83,14 +85,13 @@ static bool llama_control_vector_init(struct llama_control_vector & cvec, const return false; } ggml_backend_buffer_clear(buf, 0); - cvec.bufs.emplace_back(buf); + bufs.emplace_back(buf); } return true; } -int32_t llama_control_vector_apply( - struct llama_control_vector & cvec, +int32_t llama_adapter_cvec::apply( const llama_model & model, const float * data, size_t len, @@ -101,8 +102,8 @@ int32_t llama_control_vector_apply( if (data == nullptr) { // disable the current control vector (but leave allocated for later) - cvec.layer_start = -1; - cvec.layer_end = -1; + layer_start = -1; + layer_end = -1; return 0; } @@ -111,21 +112,21 @@ int32_t 
llama_control_vector_apply( return 1; } - if (cvec.tensors.empty()) { - if (!llama_control_vector_init(cvec, model)) { + if (tensors.empty()) { + if (!init(model)) { return 1; } } - cvec.layer_start = il_start; - cvec.layer_end = il_end; + layer_start = il_start; + layer_end = il_end; for (size_t il = 1; il < hparams.n_layer; il++) { - assert(cvec.tensors[il] != nullptr); + assert(tensors[il] != nullptr); const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present if (off + n_embd <= len) { - ggml_backend_tensor_set(cvec.tensors[il], data + off, 0, n_embd * ggml_element_size(cvec.tensors[il])); + ggml_backend_tensor_set(tensors[il], data + off, 0, n_embd * ggml_element_size(tensors[il])); } } @@ -134,7 +135,7 @@ int32_t llama_control_vector_apply( // lora -llama_lora_weight * llama_lora_adapter::get_weight(struct ggml_tensor * w) { +llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor * w) { const std::string name(w->name); const auto pos = ab_map.find(name); @@ -145,11 +146,7 @@ llama_lora_weight * llama_lora_adapter::get_weight(struct ggml_tensor * w) { return nullptr; } -void llama_lora_adapter_free(struct llama_lora_adapter * adapter) { - delete adapter; -} - -static void llama_lora_adapter_init_impl(struct llama_model & model, const char * path_lora, struct llama_lora_adapter & adapter) { +static void llama_adapter_lora_init_impl(struct llama_model & model, const char * path_lora, struct llama_adapter_lora & adapter) { LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora); ggml_context * ctx_init; @@ -221,7 +218,7 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char }; // bundle lora_a and lora_b into pairs - std::map ab_map; + std::map ab_map; auto str_endswith = [](const std::string & str, const std::string & suffix) { return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; }; @@ -231,17 +228,21 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char if (str_endswith(name, ".lora_a")) { replace_all(name, ".lora_a", ""); if (ab_map.find(name) == ab_map.end()) { - ab_map[name] = llama_lora_weight(cur, nullptr); + ab_map[name] = llama_adapter_lora_weight(cur, nullptr); } else { ab_map[name].a = cur; } } else if (str_endswith(name, ".lora_b")) { replace_all(name, ".lora_b", ""); if (ab_map.find(name) == ab_map.end()) { - ab_map[name] = llama_lora_weight(nullptr, cur); + ab_map[name] = llama_adapter_lora_weight(nullptr, cur); } else { ab_map[name].b = cur; } + } else if (str_endswith(name, "_norm.weight")) { + // TODO: add support for norm vector + // for now, we don't really care because most adapters still work fine without it + continue; } else { throw std::runtime_error("LoRA tensor '" + name + "' has unexpected suffix"); } @@ -250,25 +251,33 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char // add tensors for (auto & it : ab_map) { const std::string & name = it.first; - llama_lora_weight & w = it.second; + llama_adapter_lora_weight & w = it.second; + bool is_token_embd = str_endswith(name, "token_embd.weight"); if (!w.a || !w.b) { throw std::runtime_error("LoRA tensor pair for '" + name + "' is missing one component"); } // device buft and device ctx - auto * model_tensor = llama_model_get_tensor(model, name.c_str()); + const auto * model_tensor = model.get_tensor(name.c_str()); if (!model_tensor) { - throw std::runtime_error("LoRA tensor '" + name + "' 
does not exist in base model"); + throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)"); } struct ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer)); // validate tensor shape - if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) { - throw std::runtime_error("tensor '" + name + "' has incorrect shape"); - } - if (w.a->ne[1] != w.b->ne[0]) { - throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)"); + if (is_token_embd) { + // expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd() + if (model_tensor->ne[0] != w.b->ne[1] || model_tensor->ne[1] != w.a->ne[1]) { + throw std::runtime_error("tensor '" + name + "' has incorrect shape (hint: maybe wrong base model?)"); + } + } else { + if (model_tensor->ne[0] != w.a->ne[0] || model_tensor->ne[1] != w.b->ne[1]) { + throw std::runtime_error("tensor '" + name + "' has incorrect shape (hint: maybe wrong base model?)"); + } + if (w.a->ne[1] != w.b->ne[0]) { + throw std::runtime_error("lora_a tensor is not transposed (hint: adapter from \"finetune\" example is no longer supported)"); + } } // save tensor to adapter @@ -276,7 +285,7 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b); ggml_set_name(tensor_a, w.a->name); ggml_set_name(tensor_b, w.b->name); - adapter.ab_map[name] = llama_lora_weight(tensor_a, tensor_b); + adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b); } // allocate tensors / buffers and zero @@ -318,11 +327,11 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2); } -struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) { - struct llama_lora_adapter * adapter = new llama_lora_adapter(); +struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model, const char * path_lora) { + struct llama_adapter_lora * adapter = new llama_adapter_lora(); try { - llama_lora_adapter_init_impl(*model, path_lora, *adapter); + llama_adapter_lora_init_impl(*model, path_lora, *adapter); return adapter; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); @@ -332,3 +341,7 @@ struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, return nullptr; } + +void llama_adapter_lora_free(struct llama_adapter_lora * adapter) { + delete adapter; +} diff --git a/examples/talk-llama/llama-adapter.h b/examples/talk-llama/llama-adapter.h index 5f1870cc8ad..603fa08f6d1 100644 --- a/examples/talk-llama/llama-adapter.h +++ b/examples/talk-llama/llama-adapter.h @@ -1,66 +1,74 @@ #pragma once -#include "llama-impl.h" -#include "llama-hparams.h" +#include "llama.h" #include "ggml-cpp.h" +#include #include #include +// TODO: pimpl + // // llama_adapter_cvec // -// TODO: rename to llama_adapter_cvec -struct llama_control_vector { - std::vector ctxs; - std::vector bufs; +struct llama_adapter_cvec { + struct ggml_tensor * tensor_for(int il) const; - std::vector tensors; // per layer + struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const; + + int32_t apply( + const llama_model & model, + const float * data, + size_t len, + int32_t n_embd, + 
int32_t il_start, + int32_t il_end); + +private: + bool init(const llama_model & model); int32_t layer_start = -1; int32_t layer_end = -1; - struct ggml_tensor * tensor_for(int il) const; + std::vector ctxs; + std::vector bufs; - struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const; + std::vector tensors; // per layer }; -int32_t llama_control_vector_apply( - struct llama_control_vector & cvec, - const llama_model & model, - const float * data, - size_t len, - int32_t n_embd, - int32_t il_start, - int32_t il_end); - // // llama_adapter_lora // -// TODO: rename to llama_adapter_lora_weight -struct llama_lora_weight { +struct llama_adapter_lora_weight { struct ggml_tensor * a = nullptr; struct ggml_tensor * b = nullptr; - llama_lora_weight() = default; - llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {} + // get actual scale based on rank and alpha + float get_scale(float alpha, float adapter_scale) const { + const float rank = (float) b->ne[0]; + const float scale = alpha ? adapter_scale * alpha / rank : adapter_scale; + return scale; + } + + llama_adapter_lora_weight() = default; + llama_adapter_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {} }; -// TODO: rename to llama_adapter_lora -struct llama_lora_adapter { +struct llama_adapter_lora { // map tensor name to lora_a_b - std::unordered_map ab_map; + std::unordered_map ab_map; std::vector ctxs; std::vector bufs; float alpha; - llama_lora_adapter() = default; - ~llama_lora_adapter() = default; + llama_adapter_lora() = default; + ~llama_adapter_lora() = default; - llama_lora_weight * get_weight(struct ggml_tensor * w); + llama_adapter_lora_weight * get_weight(struct ggml_tensor * w); }; diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index 007d79f8261..d7d277e7297 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -27,6 +27,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN2VL, "qwen2vl" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, + { LLM_ARCH_PHIMOE, "phimoe" }, { LLM_ARCH_PLAMO, "plamo" }, { LLM_ARCH_CODESHELL, "codeshell" }, { LLM_ARCH_ORION, "orion" }, @@ -56,6 +57,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_RWKV6, "rwkv6" }, + { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, { LLM_ARCH_GRANITE, "granite" }, { LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_CHAMELEON, "chameleon" }, @@ -105,6 +107,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TIME_DECAY_EXTRA_DIM, "%s.time_decay_extra_dim" }, { LLM_KV_RESIDUAL_SCALE, "%s.residual_scale" }, { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, + { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -175,6 +178,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, + { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" }, { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, @@ -584,6 +588,27 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, 
"blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_PHIMOE, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_PLAMO, { @@ -1144,6 +1169,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TIME_MIX_LERP_V, "blk.%d.time_mix_lerp_v" }, { LLM_TENSOR_TIME_MIX_LERP_R, "blk.%d.time_mix_lerp_r" }, { LLM_TENSOR_TIME_MIX_LERP_G, "blk.%d.time_mix_lerp_g" }, + { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" }, { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" }, { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" }, { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" }, @@ -1161,6 +1187,32 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "blk.%d.channel_mix_receptance" }, }, }, + { + LLM_ARCH_RWKV6QWEN2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_TIME_MIX_W1, "blk.%d.time_mix_w1" }, + { LLM_TENSOR_TIME_MIX_W2, "blk.%d.time_mix_w2" }, + { LLM_TENSOR_TIME_MIX_LERP_X, "blk.%d.time_mix_lerp_x" }, + { LLM_TENSOR_TIME_MIX_LERP_FUSED, "blk.%d.time_mix_lerp_fused" }, + { LLM_TENSOR_TIME_MIX_FIRST, "blk.%d.time_mix_first" }, + { LLM_TENSOR_TIME_MIX_DECAY, "blk.%d.time_mix_decay" }, + { LLM_TENSOR_TIME_MIX_DECAY_W1, "blk.%d.time_mix_decay_w1" }, + { LLM_TENSOR_TIME_MIX_DECAY_W2, "blk.%d.time_mix_decay_w2" }, + { LLM_TENSOR_TIME_MIX_KEY, "blk.%d.time_mix_key" }, + { LLM_TENSOR_TIME_MIX_VALUE, "blk.%d.time_mix_value" }, + { LLM_TENSOR_TIME_MIX_RECEPTANCE, "blk.%d.time_mix_receptance" }, + { LLM_TENSOR_TIME_MIX_GATE, "blk.%d.time_mix_gate" }, + { LLM_TENSOR_TIME_MIX_OUTPUT, "blk.%d.time_mix_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_GRANITE, { @@ -1343,6 +1395,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_TIME_MIX_LERP_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_R, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_LERP_G, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, + {LLM_TENSOR_TIME_MIX_LERP_FUSED, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_DECAY, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, {LLM_TENSOR_TIME_MIX_FIRST, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_RWKV_WKV6}}, {LLM_TENSOR_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 45e458bb9cc..34984479045 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -31,6 +31,7 @@ enum llm_arch { LLM_ARCH_QWEN2VL, LLM_ARCH_PHI2, LLM_ARCH_PHI3, + LLM_ARCH_PHIMOE, LLM_ARCH_PLAMO, 
LLM_ARCH_CODESHELL, LLM_ARCH_ORION, @@ -60,6 +61,7 @@ enum llm_arch { LLM_ARCH_NEMOTRON, LLM_ARCH_EXAONE, LLM_ARCH_RWKV6, + LLM_ARCH_RWKV6QWEN2, LLM_ARCH_GRANITE, LLM_ARCH_GRANITE_MOE, LLM_ARCH_CHAMELEON, @@ -109,6 +111,7 @@ enum llm_kv { LLM_KV_TIME_DECAY_EXTRA_DIM, LLM_KV_RESIDUAL_SCALE, LLM_KV_EMBEDDING_SCALE, + LLM_KV_TOKEN_SHIFT_COUNT, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -173,6 +176,7 @@ enum llm_kv { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_RWKV, + LLM_KV_TOKENIZER_CHAT_TEMPLATE, LLM_KV_TOKENIZER_FIM_PRE_ID, LLM_KV_TOKENIZER_FIM_SUF_ID, LLM_KV_TOKENIZER_FIM_MID_ID, @@ -252,6 +256,7 @@ enum llm_tensor { LLM_TENSOR_TIME_MIX_LERP_V, LLM_TENSOR_TIME_MIX_LERP_R, LLM_TENSOR_TIME_MIX_LERP_G, + LLM_TENSOR_TIME_MIX_LERP_FUSED, LLM_TENSOR_TIME_MIX_FIRST, LLM_TENSOR_TIME_MIX_DECAY, LLM_TENSOR_TIME_MIX_DECAY_W1, diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index 44670d3d839..1347ec1560b 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -35,6 +35,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN }, { "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 }, { "phi3", LLM_CHAT_TEMPLATE_PHI_3 }, + { "phi4", LLM_CHAT_TEMPLATE_PHI_4 }, { "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 }, { "zephyr", LLM_CHAT_TEMPLATE_ZEPHYR }, { "monarch", LLM_CHAT_TEMPLATE_MONARCH }, @@ -73,7 +74,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return tmpl.find(haystack) != std::string::npos; }; if (tmpl_contains("<|im_start|>")) { - return LLM_CHAT_TEMPLATE_CHATML; + return tmpl_contains("<|im_sep|>") + ? LLM_CHAT_TEMPLATE_PHI_4 + : LLM_CHAT_TEMPLATE_CHATML; } else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) { if (tmpl_contains("[SYSTEM_PROMPT]")) { return LLM_CHAT_TEMPLATE_MISTRAL_V7; @@ -269,6 +272,14 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|assistant|>\n"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_PHI_4) { + // chatml template + for (auto message : chat) { + ss << "<|im_start|>" << message->role << "<|im_sep|>" << message->content << "<|im_end|>"; + } + if (add_ass) { + ss << "<|im_start|>assistant<|im_sep|>"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_FALCON_3) { // Falcon 3 for (auto message : chat) { diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index b8e94d9ef2b..3a4d07ce3de 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -15,6 +15,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN, LLM_CHAT_TEMPLATE_MISTRAL_V7, LLM_CHAT_TEMPLATE_PHI_3, + LLM_CHAT_TEMPLATE_PHI_4, LLM_CHAT_TEMPLATE_FALCON_3, LLM_CHAT_TEMPLATE_ZEPHYR, LLM_CHAT_TEMPLATE_MONARCH, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 38a55fb2cd4..671d2a81ada 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -1,5 +1,8 @@ #include "llama-context.h" +#include "llama-impl.h" +#include "llama-mmap.h" + #include #include #include @@ -467,11 +470,12 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) { size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) { const auto & cparams = lctx.cparams; const auto & hparams = lctx.model.hparams; + const auto & vocab = lctx.model.vocab; const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max); const auto n_batch = cparams.n_batch; - 
const auto n_vocab = hparams.n_vocab; + const auto n_vocab = vocab.n_tokens(); const auto n_embd = hparams.n_embd; // TODO: use a per-batch flag for logits presence instead @@ -504,7 +508,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) { auto * buft = ggml_backend_cpu_buffer_type(); // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory - auto * output_dev = lctx.model.dev_output.dev; + auto * output_dev = lctx.model.dev_output(); auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr; if (output_dev_host_buft) { buft = output_dev_host_buft; @@ -538,7 +542,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) { void llama_output_reorder(struct llama_context & ctx) { std::vector & out_ids = ctx.sbatch.out_ids; if (!out_ids.empty()) { - const uint32_t n_vocab = ctx.model.hparams.n_vocab; + const uint32_t n_vocab = ctx.model.vocab.n_tokens(); const uint32_t n_embd = ctx.model.hparams.n_embd; const int32_t n_outputs = ctx.n_outputs; @@ -722,7 +726,7 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs)); } - return ctx->logits + j*ctx->model.hparams.n_vocab; + return ctx->logits + j*ctx->model.vocab.n_tokens(); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG @@ -882,7 +886,7 @@ struct llama_data_write { } void write_logits(const struct llama_context * ctx) { - const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab); + const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.vocab.n_tokens()); write(&logits_size, sizeof(logits_size)); diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index 0d163c47090..a9268b29209 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -22,12 +22,12 @@ struct llama_context { const struct llama_model & model; - struct llama_cparams cparams; - struct llama_sbatch sbatch; // TODO: revisit if needed - struct llama_kv_cache kv_self; - struct llama_control_vector cvec; + struct llama_cparams cparams; + struct llama_sbatch sbatch; // TODO: revisit if needed + struct llama_kv_cache kv_self; + struct llama_adapter_cvec cvec; - std::unordered_map lora_adapters; + std::unordered_map lora; std::vector backends; std::vector> set_n_threads_fns; diff --git a/examples/talk-llama/llama-grammar.cpp b/examples/talk-llama/llama-grammar.cpp index 186dc9a25cf..bebe4e9a320 100644 --- a/examples/talk-llama/llama-grammar.cpp +++ b/examples/talk-llama/llama-grammar.cpp @@ -1092,9 +1092,9 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ for (size_t i = 0; i < cur_p->size; ++i) { const llama_token id = cur_p->data[i].id; - const std::string & piece = grammar.vocab->cache_token_to_piece.at(id); + const std::string & piece = grammar.vocab->token_to_piece(id); - if (llama_token_is_eog_impl(*grammar.vocab, id)) { + if (grammar.vocab->is_eog(id)) { if (!allow_eog) { cur_p->data[i].logit = -INFINITY; } @@ -1115,7 +1115,7 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { 
GGML_ASSERT(grammar.vocab != nullptr); - if (llama_token_is_eog_impl(*grammar.vocab, token)) { + if (grammar.vocab->is_eog(token)) { for (const auto & stack : grammar.stacks) { if (stack.empty()) { return; @@ -1124,7 +1124,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token GGML_ABORT("fatal error"); } - const std::string & piece = grammar.vocab->cache_token_to_piece.at(token); + const std::string & piece = grammar.vocab->token_to_piece(token); // Note terminating 0 in decoded string const auto decoded = decode_utf8(piece, grammar.partial_utf8); diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index c40534696b6..ea87b2953d9 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -52,7 +52,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { uint32_t llama_hparams::n_embd_k_s() const { if (wkv_head_size != 0) { // for RWKV models - return 2 * n_embd; + return token_shift_count * n_embd; } // TODO: maybe support other convolution strides than 1 diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index a29f20ec496..1fe45410371 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -30,7 +30,6 @@ struct llama_hparams { bool use_par_res; bool swin_norm; - uint32_t n_vocab = 0; uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; uint32_t n_embd_features = 0; @@ -41,7 +40,6 @@ struct llama_hparams { uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head uint32_t n_expert = 0; uint32_t n_expert_used = 0; - uint32_t n_vocab_type = 0; // for BERT-style token types uint32_t n_rel_attn_bkts = 0; // for WavTokenizer @@ -76,6 +74,7 @@ struct llama_hparams { uint32_t time_mix_extra_dim = 0; uint32_t time_decay_extra_dim = 0; uint32_t wkv_head_size = 0; + uint32_t token_shift_count = 2; float rope_attn_factor = 1.0f; float rope_freq_base_train; diff --git a/examples/talk-llama/llama-impl.cpp b/examples/talk-llama/llama-impl.cpp index a05ba4f635c..6ec709dd323 100644 --- a/examples/talk-llama/llama-impl.cpp +++ b/examples/talk-llama/llama-impl.cpp @@ -1,5 +1,6 @@ #include "llama-impl.h" +#include "gguf.h" #include "llama.h" #include @@ -138,7 +139,7 @@ std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) { { const enum gguf_type arr_type = gguf_get_arr_type(ctx_gguf, i); int arr_n = gguf_get_arr_n(ctx_gguf, i); - const void * data = gguf_get_arr_data(ctx_gguf, i); + const void * data = arr_type == GGUF_TYPE_STRING ? 
nullptr : gguf_get_arr_data(ctx_gguf, i); std::stringstream ss; ss << "["; for (int j = 0; j < arr_n; j++) { diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index 90b6c56ed06..feffdf0de52 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -79,7 +79,7 @@ bool llama_kv_cache_init( ggml_backend_buffer_type_t buft; if (offload) { - auto * dev = model.dev_layer.at(i).dev; + auto * dev = model.dev_layer(i); buft = ggml_backend_dev_buffer_type(dev); } else { buft = ggml_backend_cpu_buffer_type(); diff --git a/examples/talk-llama/llama-mmap.cpp b/examples/talk-llama/llama-mmap.cpp index a8cb9439b6b..57c6e4f510f 100644 --- a/examples/talk-llama/llama-mmap.cpp +++ b/examples/talk-llama/llama-mmap.cpp @@ -35,7 +35,7 @@ // TODO: consider moving to llama-impl.h if needed in more places #if defined(_WIN32) -std::string llama_format_win_err(DWORD err) { +static std::string llama_format_win_err(DWORD err) { LPSTR buf; size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL); diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index 7743b46522c..53175f0e069 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -7,6 +7,10 @@ #include #include +static const size_t kiB = 1024; +static const size_t MiB = 1024*kiB; +static const size_t GiB = 1024*MiB; + const char * llama_file_version_name(llama_fver version) { switch (version) { case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)"; @@ -17,8 +21,51 @@ const char * llama_file_version_name(llama_fver version) { return "unknown"; } +static std::string llama_model_ftype_name(llama_ftype ftype) { + if (ftype & LLAMA_FTYPE_GUESSED) { + return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)"; + } + + switch (ftype) { + case LLAMA_FTYPE_ALL_F32: return "all F32"; + case LLAMA_FTYPE_MOSTLY_F16: return "F16"; + case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; + case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; + case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; + case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; + case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; + case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; + case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large"; + case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small"; + case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium"; + case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; + case LLAMA_FTYPE_MOSTLY_TQ1_0: return "TQ1_0 - 1.69 bpw ternary"; + case LLAMA_FTYPE_MOSTLY_TQ2_0: return "TQ2_0 - 2.06 bpw ternary"; + case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S 
- 1.5625 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; + case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; + + default: return "unknown, may not work"; + } +} + namespace GGUFMeta { - template + template struct GKV_Base_Type { static constexpr gguf_type gt = gt_; @@ -60,10 +107,11 @@ namespace GGUFMeta { public: static constexpr gguf_type gt = GGUF_TYPE_ARRAY; static ArrayInfo getter(const gguf_context *ctx, const int k) { + const enum gguf_type arr_type = gguf_get_arr_type(ctx, k); return ArrayInfo { - gguf_get_arr_type(ctx, k), + arr_type, size_t(gguf_get_arr_n(ctx, k)), - gguf_get_arr_data(ctx, k), + arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx, k), }; } }; @@ -553,7 +601,7 @@ llama_model_loader::llama_model_loader(const std::string & fname, bool use_mmap, const enum gguf_type type = gguf_get_kv_type(meta.get(), i); const std::string type_name = type == GGUF_TYPE_ARRAY - ? format("%s[%s,%d]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta.get(), i)), gguf_get_arr_n(meta.get(), i)) + ? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta.get(), i)), gguf_get_arr_n(meta.get(), i)) : gguf_type_name(type); std::string value = gguf_kv_to_str(meta.get(), i); @@ -1008,3 +1056,17 @@ bool llama_model_loader::load_all_data( return true; } + +std::string llama_model_loader::ftype_name() const { + return llama_model_ftype_name(ftype); +} + +void llama_model_loader::print_info() const { + LLAMA_LOG_INFO("%s: file format = %s\n", __func__, llama_file_version_name(fver)); + LLAMA_LOG_INFO("%s: file type = %s\n", __func__, llama_model_ftype_name(ftype).c_str()); + if (n_bytes < GiB) { + LLAMA_LOG_INFO("%s: file size = %.2f MiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0, n_bytes*8.0/n_elements); + } else { + LLAMA_LOG_INFO("%s: file size = %.2f GiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0/1024.0, n_bytes*8.0/n_elements); + } +} diff --git a/examples/talk-llama/llama-model-loader.h b/examples/talk-llama/llama-model-loader.h index 1ec47819567..b63d158d982 100644 --- a/examples/talk-llama/llama-model-loader.h +++ b/examples/talk-llama/llama-model-loader.h @@ -155,4 +155,8 @@ struct llama_model_loader { llama_mlocks * lmlocks, llama_progress_callback progress_callback, void * progress_callback_user_data); + + std::string ftype_name() const; + + void print_info() const; }; diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index 7deb3683bbc..f90f5e74607 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -1,128 +1,85 @@ #include "llama-model.h" #include "llama-impl.h" +#include "llama-mmap.h" #include "llama-model-loader.h" -#include "unicode.h" // TODO: remove +#include "ggml-cpp.h" #include #include +#include #include +#include #include #include -static const size_t kiB = 1024; -static const size_t MiB = 1024*kiB; -static const size_t GiB = 1024*MiB; - const char * llm_type_name(llm_type type) { switch (type) { - case MODEL_14M: return "14M"; - case MODEL_17M: return "17M"; - case MODEL_22M: return "22M"; - case MODEL_33M: return "33M"; - case MODEL_60M: return "60M"; - case MODEL_70M: return "70M"; - case MODEL_80M: return "80M"; - case MODEL_109M: return "109M"; - case MODEL_137M: return "137M"; - case MODEL_160M: return "160M"; - case 
MODEL_220M: return "220M"; - case MODEL_250M: return "250M"; - case MODEL_270M: return "270M"; - case MODEL_335M: return "335M"; - case MODEL_410M: return "410M"; - case MODEL_450M: return "450M"; - case MODEL_770M: return "770M"; - case MODEL_780M: return "780M"; - case MODEL_0_5B: return "0.5B"; - case MODEL_1B: return "1B"; - case MODEL_1_3B: return "1.3B"; - case MODEL_1_4B: return "1.4B"; - case MODEL_1_5B: return "1.5B"; - case MODEL_1_6B: return "1.6B"; - case MODEL_2B: return "2B"; - case MODEL_2_8B: return "2.8B"; - case MODEL_3B: return "3B"; - case MODEL_4B: return "4B"; - case MODEL_6B: return "6B"; - case MODEL_6_9B: return "6.9B"; - case MODEL_7B: return "7B"; - case MODEL_8B: return "8B"; - case MODEL_9B: return "9B"; - case MODEL_11B: return "11B"; - case MODEL_12B: return "12B"; - case MODEL_13B: return "13B"; - case MODEL_14B: return "14B"; - case MODEL_15B: return "15B"; - case MODEL_16B: return "16B"; - case MODEL_20B: return "20B"; - case MODEL_30B: return "30B"; - case MODEL_32B: return "32B"; - case MODEL_34B: return "34B"; - case MODEL_35B: return "35B"; - case MODEL_40B: return "40B"; - case MODEL_65B: return "65B"; - case MODEL_70B: return "70B"; - case MODEL_236B: return "236B"; - case MODEL_314B: return "314B"; - case MODEL_671B: return "671B"; - case MODEL_SMALL: return "0.1B"; - case MODEL_MEDIUM: return "0.4B"; - case MODEL_LARGE: return "0.8B"; - case MODEL_XL: return "1.5B"; - case MODEL_A1_7B: return "A1.7B"; - case MODEL_A2_7B: return "A2.7B"; - case MODEL_8x7B: return "8x7B"; - case MODEL_8x22B: return "8x22B"; - case MODEL_16x12B: return "16x12B"; - case MODEL_10B_128x3_66B: return "10B+128x3.66B"; - case MODEL_57B_A14B: return "57B.A14B"; - case MODEL_27B: return "27B"; - default: return "?B"; - } -} - -static std::string llama_model_ftype_name(llama_ftype ftype) { - if (ftype & LLAMA_FTYPE_GUESSED) { - return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)"; - } - - switch (ftype) { - case LLAMA_FTYPE_ALL_F32: return "all F32"; - case LLAMA_FTYPE_MOSTLY_F16: return "F16"; - case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; - case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; - case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; - case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; - case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; - case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; - case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large"; - case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small"; - case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; - case LLAMA_FTYPE_MOSTLY_TQ1_0: return "TQ1_0 - 1.69 bpw ternary"; - case LLAMA_FTYPE_MOSTLY_TQ2_0: return "TQ2_0 - 2.06 bpw ternary"; - case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S - 1.5625 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 
1.75 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; - case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; - - default: return "unknown, may not work"; + case LLM_TYPE_14M: return "14M"; + case LLM_TYPE_17M: return "17M"; + case LLM_TYPE_22M: return "22M"; + case LLM_TYPE_33M: return "33M"; + case LLM_TYPE_60M: return "60M"; + case LLM_TYPE_70M: return "70M"; + case LLM_TYPE_80M: return "80M"; + case LLM_TYPE_109M: return "109M"; + case LLM_TYPE_137M: return "137M"; + case LLM_TYPE_160M: return "160M"; + case LLM_TYPE_220M: return "220M"; + case LLM_TYPE_250M: return "250M"; + case LLM_TYPE_270M: return "270M"; + case LLM_TYPE_335M: return "335M"; + case LLM_TYPE_410M: return "410M"; + case LLM_TYPE_450M: return "450M"; + case LLM_TYPE_770M: return "770M"; + case LLM_TYPE_780M: return "780M"; + case LLM_TYPE_0_5B: return "0.5B"; + case LLM_TYPE_1B: return "1B"; + case LLM_TYPE_1_3B: return "1.3B"; + case LLM_TYPE_1_4B: return "1.4B"; + case LLM_TYPE_1_5B: return "1.5B"; + case LLM_TYPE_1_6B: return "1.6B"; + case LLM_TYPE_2B: return "2B"; + case LLM_TYPE_2_8B: return "2.8B"; + case LLM_TYPE_3B: return "3B"; + case LLM_TYPE_4B: return "4B"; + case LLM_TYPE_6B: return "6B"; + case LLM_TYPE_6_9B: return "6.9B"; + case LLM_TYPE_7B: return "7B"; + case LLM_TYPE_8B: return "8B"; + case LLM_TYPE_9B: return "9B"; + case LLM_TYPE_11B: return "11B"; + case LLM_TYPE_12B: return "12B"; + case LLM_TYPE_13B: return "13B"; + case LLM_TYPE_14B: return "14B"; + case LLM_TYPE_15B: return "15B"; + case LLM_TYPE_16B: return "16B"; + case LLM_TYPE_20B: return "20B"; + case LLM_TYPE_30B: return "30B"; + case LLM_TYPE_32B: return "32B"; + case LLM_TYPE_34B: return "34B"; + case LLM_TYPE_35B: return "35B"; + case LLM_TYPE_40B: return "40B"; + case LLM_TYPE_65B: return "65B"; + case LLM_TYPE_70B: return "70B"; + case LLM_TYPE_236B: return "236B"; + case LLM_TYPE_314B: return "314B"; + case LLM_TYPE_671B: return "671B"; + case LLM_TYPE_SMALL: return "0.1B"; + case LLM_TYPE_MEDIUM: return "0.4B"; + case LLM_TYPE_LARGE: return "0.8B"; + case LLM_TYPE_XL: return "1.5B"; + case LLM_TYPE_A1_7B: return "A1.7B"; + case LLM_TYPE_A2_7B: return "A2.7B"; + case LLM_TYPE_8x7B: return "8x7B"; + case LLM_TYPE_8x22B: return "8x22B"; + case LLM_TYPE_16x12B: return "16x12B"; + case LLM_TYPE_16x3_8B: return "16x3.8B"; + case LLM_TYPE_10B_128x3_66B: return "10B+128x3.66B"; + case LLM_TYPE_57B_A14B: return "57B.A14B"; + case LLM_TYPE_27B: return "27B"; + default: return "?B"; } } @@ -134,132 +91,301 @@ static const char * llama_expert_gating_func_name(llama_expert_gating_func_type } } -std::string llama_model_arch_name (const llama_model & model) { - return llm_arch_name(model.arch); -} +static const std::map LLAMA_ROPE_SCALING_TYPES = { + { LLAMA_ROPE_SCALING_TYPE_NONE, "none" }, + { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" }, + { LLAMA_ROPE_SCALING_TYPE_YARN, "yarn" }, + { LLAMA_ROPE_SCALING_TYPE_LONGROPE, "longrope" }, +}; -std::string llama_model_type_name (const llama_model & model) { - return llm_type_name(model.type); -} +static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) { + for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) { + if (kv.second == name) { + return (llama_rope_scaling_type) kv.first; + } + } -std::string llama_model_ftype_name(const llama_model & model) { - return llama_model_ftype_name(model.ftype); + return 
LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; } -template -static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) { +// checks if the weight tensor can be used with the specified buffer type and device +static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { + GGML_ASSERT(w != nullptr); + + if (op == GGML_OP_NONE) { + return true; + } + ggml_init_params params = { /*.mem_size =*/ ggml_tensor_overhead()*8, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - - ggml_context_ptr ctx { ggml_init(params) }; - if (!ctx) { + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { throw std::runtime_error(format("failed to create ggml context")); } + ggml_context * ctx = ctx_ptr.get(); - ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) }; - ggml_tensor * op_tensor = fn(ctx.get()); - for (int i = 0; i < GGML_MAX_SRC; i++) { - if (op_tensor->src[i] != nullptr) { - assert(op_tensor->src[i]->buffer == nullptr); - op_tensor->src[i]->buffer = buf.get(); - } + ggml_tensor * op_tensor = nullptr; + + switch (op) { + case GGML_OP_GET_ROWS: + { + ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); + op_tensor = ggml_get_rows(ctx, w, b); + } break; + case GGML_OP_MUL_MAT: + { + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]); + op_tensor = ggml_mul_mat(ctx, w, b); + } break; + case GGML_OP_MUL_MAT_ID: + { + int n_expert_used = hparams.n_expert_used; + ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); + ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); + op_tensor = ggml_mul_mat_id(ctx, w, b, ids); + } break; + case GGML_OP_ADD: + { + ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_add(ctx, a, w); + } break; + case GGML_OP_MUL: + { + ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_mul(ctx, a, w); + } break; + case GGML_OP_DIV: + { + ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, w->ne[0]); + op_tensor = ggml_div(ctx, a, w); + } break; + case GGML_OP_ROPE: + { + int n_embd_head = hparams.n_embd_head_v; + int n_head = hparams.n_head(); + ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512); + ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); + op_tensor = ggml_rope_ext( + ctx, a, b, w, + 0, 0, 0, 0, 0, + 0, 0, 0, 0 + ); + + } break; + case GGML_OP_SSM_CONV: + { + // FIXME + ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789); + op_tensor = ggml_ssm_conv(ctx, conv_x, w); + } break; + case GGML_OP_SSM_SCAN: + { + // FIXME + const int64_t d_state = w->ne[0]; + const int64_t d_inner = w->ne[1]; + const int64_t n_seq_tokens = 512; + const int64_t n_seqs = 1; + ggml_tensor * s = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs); + ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); + ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); + op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C); + } break; + case GGML_OP_RWKV_WKV6: + { + // FIXME + const int64_t S = 
123; + const int64_t H = 123; + const int64_t n_tokens = 123; + const int64_t n_seqs = 123; + ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * tf = w; + ggml_tensor * td = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H); + op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state); + } break; + case GGML_OP_IM2COL: + { + const int n_embd = hparams.n_embd; + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1); + op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16); + } break; + default: + GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); } + // create a temporary dummy buffer for the weight so that supports_op can check the buffer type + GGML_ASSERT(w->buffer == nullptr); + w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + ggml_backend_buffer_free(w->buffer); + w->buffer = nullptr; return op_supported; } -template -static ggml_backend_buffer_type_t select_buft(const llama_model::buft_list_t & buft_list, const F & fn) { +// lists of buffer types used for each layer +using buft_list_t = std::vector>; + +// find the first buffer type in the list that can use the tensor +static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hparams, ggml_tensor * tensor, ggml_op op, const buft_list_t & buft_list) { + GGML_ASSERT(!buft_list.empty()); for (const auto & cur : buft_list) { ggml_backend_dev_t cur_dev = cur.first; ggml_backend_buffer_type_t cur_buft = cur.second; - if (buft_supported(cur_buft, cur_dev, fn)) { + if (weight_buft_supported(hparams, tensor, op, cur_buft, cur_dev)) { return cur_buft; } } - - throw std::runtime_error(format("no suitable buffer type found")); + return nullptr; } -ggml_backend_buffer_type_t llama_model_select_buft(const llama_model & model, int il) { - return select_buft( - *model.dev_layer.at(il).buft_list, - [&](ggml_context * ctx) { - ggml_tensor * cur = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_embd); - ggml_tensor * layer_dir = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model.hparams.n_embd); - return ggml_add(ctx, cur, layer_dir); - }); -} +// CPU: ACCEL -> CPU extra -> GPU host -> CPU +static buft_list_t make_cpu_buft_list(const std::vector & devices) { + buft_list_t buft_list; + + // add ACCEL buffer types + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { + auto * buft = ggml_backend_dev_buffer_type(dev); + // skip + if (buft != ggml_backend_cpu_buffer_type()) { + buft_list.emplace_back(dev, buft); + } + } + } -struct ggml_tensor * llama_model_get_tensor(const struct llama_model & model, const char * name) { - auto it = std::find_if(model.tensors_by_name.begin(), model.tensors_by_name.end(), - [name](const std::pair & it) { - return it.first == name; - }); - if (it == model.tensors_by_name.end()) { - return nullptr; + // add extra buffer types + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + 
ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(cpu_dev, *extra_bufts); + ++extra_bufts; + } } - return it->second; -} + // add a host buffer type + // storing the tensors in a host buffer is useful when the processing of large batches + // is offloaded to a GPU device, since it reduces the time spent on data transfers + // generally, this will be done using the first device in the list + // a better approach would be to handle this on a weight-by-weight basis using the offload_op + // function of the device to determine if it would benefit from being stored in a host buffer + for (auto * dev : devices) { + ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); + if (buft) { + buft_list.emplace_back(dev, buft); + break; + } + } -size_t llama_model_max_nodes(const llama_model & model) { - return std::max(8192, model.tensors_by_name.size()*5); -} + // add the CPU buffer type + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) { + buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev)); + } + } -static const std::map LLAMA_ROPE_SCALING_TYPES = { - { LLAMA_ROPE_SCALING_TYPE_NONE, "none" }, - { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" }, - { LLAMA_ROPE_SCALING_TYPE_YARN, "yarn" }, - { LLAMA_ROPE_SCALING_TYPE_LONGROPE, "longrope" }, -}; + return buft_list; +} -static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) { - for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) { - if (kv.second == name) { - return (llama_rope_scaling_type) kv.first; +// GPU: split if LLAMA_SPLIT_MODE_ROW -> GPU +static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, enum llama_split_mode split_mode, const float * tensor_split) { + buft_list_t buft_list; + + // add the device split buffer type if requested and available + if (split_mode == LLAMA_SPLIT_MODE_ROW) { + ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); + auto ggml_backend_split_buffer_type_fn = (ggml_backend_split_buffer_type_t) + ggml_backend_reg_get_proc_address(reg, "ggml_backend_split_buffer_type"); + if (ggml_backend_split_buffer_type_fn) { + size_t dev_index = [&]() { + auto * reg = ggml_backend_dev_backend_reg(dev); + for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); ++i) { + if (ggml_backend_reg_dev_get(reg, i) == dev) { + return i; + } + } + throw std::runtime_error(format("device %s not found in its backend reg", ggml_backend_dev_name(dev))); + }(); + auto * buft = ggml_backend_split_buffer_type_fn(dev_index, tensor_split); + if (buft != nullptr) { + buft_list.emplace_back(dev, buft); + } } } - return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; + // add the device default buffer type + buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev)); + + return buft_list; } -// NOTE: avoid ever using this except for building the token_to_piece caches -static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) { - std::string piece; - piece.resize(piece.capacity()); // using string internal cache - const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); - if (n_chars < 0) { - piece.resize(-n_chars); - int check = llama_token_to_piece(model, 
token, &piece[0], piece.size(), 0, special); - GGML_ASSERT(check == -n_chars); - } - else { - piece.resize(n_chars); - } +struct llama_model::impl { + impl() {} + ~impl() {} + + uint64_t n_elements = 0; + + size_t n_bytes = 0; + + std::string desc_str; + + // model memory mapped files + llama_mmaps mappings; + + // objects representing data potentially being locked in memory + llama_mlocks mlock_bufs; + llama_mlocks mlock_mmaps; + + // contexts where the model tensors metadata is stored + std::vector ctxs; + + // the model memory buffers for the tensor data + std::vector bufs; + + buft_list_t cpu_buft_list; + std::map gpu_buft_list; + + struct layer_dev { + ggml_backend_dev_t dev; + buft_list_t * buft_list; + }; - return piece; + layer_dev dev_input = {}; + layer_dev dev_output = {}; + std::vector dev_layer; +}; + +llama_model::llama_model(const struct llama_model_params & params) : params(params), pimpl(std::make_unique()) { } -void llm_load_stats(llama_model_loader & ml, llama_model & model) { - model.n_elements = ml.n_elements; - model.n_bytes = ml.n_bytes; +llama_model::~llama_model() {} + +void llama_model::load_stats(llama_model_loader & ml) { + pimpl->n_elements = ml.n_elements; + pimpl->n_bytes = ml.n_bytes; } -void llm_load_arch(llama_model_loader & ml, llama_model & model) { - model.arch = ml.get_arch(); - if (model.arch == LLM_ARCH_UNKNOWN) { +void llama_model::load_arch(llama_model_loader & ml) { + arch = ml.get_arch(); + if (arch == LLM_ARCH_UNKNOWN) { throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); } } -void llm_load_hparams(llama_model_loader & ml, llama_model & model) { - auto & hparams = model.hparams; +void llama_model::load_hparams(llama_model_loader & ml) { const gguf_context * ctx = ml.meta.get(); // get metadata as string @@ -270,14 +396,11 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { } const char * name = gguf_get_key(ctx, i); const std::string value = gguf_kv_to_str(ctx, i); - model.gguf_kv.emplace(name, value); + gguf_kv.emplace(name, value); } // get general kv - ml.get_key(LLM_KV_GENERAL_NAME, model.name, false); - - // get hparams kv - ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab, false); + ml.get_key(LLM_KV_GENERAL_NAME, name, false); // everything past this point is not vocab-related if (hparams.vocab_only) { @@ -290,7 +413,7 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); - if (model.arch == LLM_ARCH_WAVTOKENIZER_DEC) { + if (arch == LLM_ARCH_WAVTOKENIZER_DEC) { ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features); ml.get_key(LLM_KV_POSNET_EMBEDDING_LENGTH, hparams.posnet.n_embd); @@ -363,7 +486,7 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); - if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_DECI || model.arch == LLM_ARCH_FALCON) { + if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) { if (hparams.n_rot != hparams.n_embd_head_k) { throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); } @@ -374,34 +497,36 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { hparams.n_embd_head_v = 0; } - using e_model = llm_type; // TMP + // for differentiating model types + uint32_t n_vocab = 0; 
+ ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); // arch-specific KVs - switch (model.arch) { + switch (arch) { case LLM_ARCH_LLAMA: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); if (hparams.n_expert == 8) { switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_8x7B; break; - case 56: model.type = e_model::MODEL_8x22B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_8x7B; break; + case 56: type = LLM_TYPE_8x22B; break; + default: type = LLM_TYPE_UNKNOWN; } } else { switch (hparams.n_layer) { - case 16: model.type = e_model::MODEL_1B; break; // Llama 3.2 1B - case 22: model.type = e_model::MODEL_1B; break; - case 26: model.type = e_model::MODEL_3B; break; - case 28: model.type = e_model::MODEL_3B; break; // Llama 3.2 3B + case 16: type = LLM_TYPE_1B; break; // Llama 3.2 1B + case 22: type = LLM_TYPE_1B; break; + case 26: type = LLM_TYPE_3B; break; + case 28: type = LLM_TYPE_3B; break; // Llama 3.2 3B // granite uses a vocab with len 49152 - case 32: model.type = hparams.n_vocab == 49152 ? e_model::MODEL_3B : (hparams.n_vocab < 40000 ? e_model::MODEL_7B : e_model::MODEL_8B); break; - case 36: model.type = e_model::MODEL_8B; break; // granite - case 40: model.type = e_model::MODEL_13B; break; - case 48: model.type = e_model::MODEL_34B; break; - case 60: model.type = e_model::MODEL_30B; break; - case 80: model.type = hparams.n_head() == hparams.n_head_kv() ? e_model::MODEL_65B : e_model::MODEL_70B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = n_vocab == 49152 ? LLM_TYPE_3B : (n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B); break; + case 36: type = LLM_TYPE_8B; break; // granite + case 40: type = LLM_TYPE_13B; break; + case 48: type = LLM_TYPE_34B; break; + case 60: type = LLM_TYPE_30B; break; + case 80: type = hparams.n_head() == hparams.n_head_kv() ? 
LLM_TYPE_65B : LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; } } } break; @@ -409,33 +534,33 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 80: model.type = e_model::MODEL_70B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 80: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_MINICPM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); - ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); switch (hparams.n_layer) { - case 52: model.type = e_model::MODEL_1B; break; - case 40: model.type = e_model::MODEL_2B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 52: type = LLM_TYPE_1B; break; + case 40: type = LLM_TYPE_2B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_MINICPM3: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); switch (hparams.n_layer) { - case 62: model.type = e_model::MODEL_4B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 62: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_GROK: @@ -443,8 +568,8 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 64: model.type = e_model::MODEL_314B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 64: type = LLM_TYPE_314B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_FALCON: @@ -452,21 +577,21 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 60: model.type = e_model::MODEL_40B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 60: type = LLM_TYPE_40B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_BAICHUAN: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_13B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; } - if (model.type == e_model::MODEL_13B) { + if (type == LLM_TYPE_13B) { // TODO: become GGUF KV parameter hparams.f_max_alibi_bias = 8.0f; } @@ -475,19 +600,19 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1B; break; - case 36: model.type = 
e_model::MODEL_3B; break; - case 42: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_15B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24: type = LLM_TYPE_1B; break; + case 36: type = LLM_TYPE_3B; break; + case 42: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_15B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_REFACT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_1B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; } // TODO: become GGUF KV parameter @@ -497,48 +622,45 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch (hparams.n_layer) { case 3: - model.type = e_model::MODEL_17M; break; // bge-micro + type = LLM_TYPE_17M; break; // bge-micro case 6: - model.type = e_model::MODEL_22M; break; // MiniLM-L6 + type = LLM_TYPE_22M; break; // MiniLM-L6 case 12: switch (hparams.n_embd) { - case 384: model.type = e_model::MODEL_33M; break; // MiniLM-L12, bge-small - case 768: model.type = e_model::MODEL_109M; break; // bge-base - default: model.type = e_model::MODEL_UNKNOWN; + case 384: type = LLM_TYPE_33M; break; // MiniLM-L12, bge-small + case 768: type = LLM_TYPE_109M; break; // bge-base + default: type = LLM_TYPE_UNKNOWN; } break; case 24: - model.type = e_model::MODEL_335M; break; // bge-large - default: model.type = e_model::MODEL_UNKNOWN; + type = LLM_TYPE_335M; break; // bge-large + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_JINA_BERT_V2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); hparams.f_max_alibi_bias = 8.0f; switch (hparams.n_layer) { - case 4: model.type = e_model::MODEL_33M; break; // jina-embeddings-small - case 12: model.type = e_model::MODEL_137M; break; // jina-embeddings-base - default: model.type = e_model::MODEL_UNKNOWN; + case 4: type = LLM_TYPE_33M; break; // jina-embeddings-small + case 12: type = LLM_TYPE_137M; break; // jina-embeddings-base + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_NOMIC_BERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); if (hparams.n_layer == 12 && hparams.n_embd == 768) { - model.type = e_model::MODEL_137M; + type = LLM_TYPE_137M; } } break; case LLM_ARCH_BLOOM: @@ -546,14 +668,14 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1B; break; + case 24: type = LLM_TYPE_1B; break; case 30: switch (hparams.n_embd) { - case 2560: model.type = e_model::MODEL_3B; break; - case 4096: model.type = e_model::MODEL_7B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = 
LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; } break; - default: model.type = e_model::MODEL_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } // TODO: become GGUF KV parameter @@ -566,9 +688,9 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 48: model.type = e_model::MODEL_30B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_30B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_STABLELM: @@ -576,10 +698,10 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1B; break; - case 32: model.type = e_model::MODEL_3B; break; - case 40: model.type = e_model::MODEL_12B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_12B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_QWEN: @@ -587,9 +709,9 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_13B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_QWEN2VL: @@ -601,27 +723,27 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 24: model.type = hparams.n_embd == 1024 ? e_model::MODEL_0_5B : e_model::MODEL_1B; break; - case 28: model.type = hparams.n_embd == 1536 ? e_model::MODEL_1_5B : e_model::MODEL_7B; break; - case 32: model.type = e_model::MODEL_7B; break; - case 36: model.type = e_model::MODEL_3B; break; - case 40: model.type = hparams.n_head() == 20 ? e_model::MODEL_4B : e_model::MODEL_13B; break; - case 48: model.type = e_model::MODEL_14B; break; - case 64: model.type = e_model::MODEL_32B; break; - case 80: model.type = e_model::MODEL_70B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break; + case 28: type = hparams.n_embd == 1536 ? LLM_TYPE_1_5B : LLM_TYPE_7B; break; + case 32: type = LLM_TYPE_7B; break; + case 36: type = LLM_TYPE_3B; break; + case 40: type = hparams.n_head() == 20 ? 
LLM_TYPE_4B : LLM_TYPE_13B; break; + case 48: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + case 80: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_QWEN2MOE: { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_A2_7B; break; - case 28: model.type = e_model::MODEL_57B_A14B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24: type = LLM_TYPE_A2_7B; break; + case 28: type = LLM_TYPE_57B_A14B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_PHI2: @@ -629,9 +751,9 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1B; break; - case 32: model.type = e_model::MODEL_3B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_PHI3: @@ -639,10 +761,10 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1B; break; - case 32: model.type = e_model::MODEL_3B; break; - case 40: model.type = e_model::MODEL_14B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; } // for backward compatibility ; see: https://github.com/ggerganov/llama.cpp/pull/8931 @@ -661,32 +783,41 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { throw std::runtime_error("invalid value for sliding_window"); } } break; + case LLM_ARCH_PHIMOE: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_16x3_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_PLAMO: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 40: model.type = e_model::MODEL_13B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_GPT2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 12: model.type = e_model::MODEL_SMALL; break; - case 24: model.type = e_model::MODEL_MEDIUM; break; - case 36: model.type = e_model::MODEL_LARGE; break; - case 48: model.type = e_model::MODEL_XL; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 12: type = LLM_TYPE_SMALL; break; + case 24: type = LLM_TYPE_MEDIUM; break; + case 36: type = LLM_TYPE_LARGE; break; + case 48: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_CODESHELL: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 42: model.type = e_model::MODEL_7B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 42: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case 
LLM_ARCH_ORION: @@ -694,17 +825,17 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 40: model.type = e_model::MODEL_14B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_INTERNLM2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 48: model.type = e_model::MODEL_20B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_GEMMA: @@ -712,37 +843,37 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 18: model.type = e_model::MODEL_2B; break; - case 28: model.type = e_model::MODEL_7B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 18: type = LLM_TYPE_2B; break; + case 28: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_GEMMA2: { hparams.n_swa = 4096; // default value of gemma 2 - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); - ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); hparams.attn_soft_cap = true; switch (hparams.n_layer) { - case 26: model.type = e_model::MODEL_2B; break; - case 42: model.type = e_model::MODEL_9B; break; - case 46: model.type = e_model::MODEL_27B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 26: type = LLM_TYPE_2B; break; + case 42: type = LLM_TYPE_9B; break; + case 46: type = LLM_TYPE_27B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_STARCODER2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 30: model.type = e_model::MODEL_3B; break; - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_15B; break; - case 52: model.type = e_model::MODEL_20B; break; // granite - case 88: model.type = e_model::MODEL_34B; break; // granite - default: model.type = e_model::MODEL_UNKNOWN; + case 30: type = LLM_TYPE_3B; break; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_15B; break; + case 52: type = LLM_TYPE_20B; break; // granite + case 88: type = LLM_TYPE_34B; break; // granite + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_MAMBA: @@ -758,51 +889,51 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { switch (hparams.n_layer) { case 24: switch (hparams.n_embd) { - case 768: model.type = e_model::MODEL_SMALL; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 768: type = LLM_TYPE_SMALL; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 48: switch (hparams.n_embd) { - case 1024: model.type = e_model::MODEL_MEDIUM; break; - case 1536: model.type = e_model::MODEL_LARGE; break; - case 2048: model.type 
= e_model::MODEL_XL; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 1024: type = LLM_TYPE_MEDIUM; break; + case 1536: type = LLM_TYPE_LARGE; break; + case 2048: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 64: switch (hparams.n_embd) { - case 2560: model.type = e_model::MODEL_3B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 2560: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; } break; - default: model.type = e_model::MODEL_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_13B; break; - case 80: model.type = e_model::MODEL_65B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + case 80: type = LLM_TYPE_65B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_COMMAND_R: { - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 40: model.type = e_model::MODEL_35B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 40: type = LLM_TYPE_35B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_COHERE2: { ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_8B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_DBRX: @@ -811,8 +942,8 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); switch (hparams.n_layer) { - case 40: model.type = e_model::MODEL_16x12B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 40: type = LLM_TYPE_16x12B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_OLMO: @@ -821,10 +952,10 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); switch (hparams.n_layer) { - case 22: model.type = e_model::MODEL_1B; break; - case 32: model.type = e_model::MODEL_7B; break; - case 80: model.type = e_model::MODEL_70B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 22: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_7B; break; + case 80: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_OLMO2: @@ -832,18 +963,18 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 16: model.type = e_model::MODEL_1B; break; - case 32: model.type = e_model::MODEL_7B; break; - case 40: model.type = e_model::MODEL_13B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 16: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case 
LLM_ARCH_OLMOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 16: model.type = e_model::MODEL_A1_7B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 16: type = LLM_TYPE_A1_7B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_OPENELM: @@ -851,57 +982,57 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 16: model.type = e_model::MODEL_270M; break; - case 20: model.type = e_model::MODEL_450M; break; - case 28: model.type = e_model::MODEL_1B; break; - case 36: model.type = e_model::MODEL_3B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 16: type = LLM_TYPE_270M; break; + case 20: type = LLM_TYPE_450M; break; + case 28: type = LLM_TYPE_1B; break; + case 36: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_GPTNEOX: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); + ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); switch (hparams.n_layer) { case 6: switch (hparams.n_ff()) { - case 512: model.type = e_model::MODEL_14M; break; - case 2048: model.type = e_model::MODEL_70M; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 512: type = LLM_TYPE_14M; break; + case 2048: type = LLM_TYPE_70M; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 12: switch (hparams.n_ff()) { - case 3072: model.type = e_model::MODEL_160M; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 3072: type = LLM_TYPE_160M; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 16: switch (hparams.n_ff()) { - case 8192: model.type = e_model::MODEL_1B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 8192: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 24: switch (hparams.n_ff()) { - case 4096: model.type = e_model::MODEL_410M; break; - case 8192: model.type = e_model::MODEL_1_4B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 4096: type = LLM_TYPE_410M; break; + case 8192: type = LLM_TYPE_1_4B; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 32: switch (hparams.n_ff()) { - case 10240: model.type = e_model::MODEL_2_8B; break; - case 16384: model.type = e_model::MODEL_6_9B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 10240: type = LLM_TYPE_2_8B; break; + case 16384: type = LLM_TYPE_6_9B; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 36: switch (hparams.n_ff()) { - case 20480: model.type = e_model::MODEL_12B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 20480: type = LLM_TYPE_12B; break; + default: type = LLM_TYPE_UNKNOWN; } break; case 44: switch (hparams.n_ff()) { - case 24576: model.type = e_model::MODEL_20B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 24576: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; } break; - default: model.type = e_model::MODEL_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_ARCTIC: @@ -910,40 +1041,40 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { if (hparams.n_expert == 128) { switch (hparams.n_layer) { - case 35: model.type = e_model::MODEL_10B_128x3_66B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 35: type = LLM_TYPE_10B_128x3_66B; break; + default: type = LLM_TYPE_UNKNOWN; } } else { - 
model.type = e_model::MODEL_UNKNOWN; + type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_DEEPSEEK: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); switch (hparams.n_layer) { - case 28: model.type = e_model::MODEL_20B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 28: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_DEEPSEEK2: { bool is_lite = (hparams.n_layer == 27); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); if (!is_lite) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); } - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { // for compatibility with existing DeepSeek V2 and V2.5 GGUFs // that have no expert_gating_func model parameter set @@ -952,19 +1083,19 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); switch (hparams.n_layer) { - case 27: model.type = e_model::MODEL_16B; break; - case 60: model.type = e_model::MODEL_236B; break; - case 61: model.type = e_model::MODEL_671B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 27: type = LLM_TYPE_16B; break; + case 60: type = LLM_TYPE_236B; break; + case 61: type = LLM_TYPE_671B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_CHATGLM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 28: model.type = e_model::MODEL_6B; break; - case 40: model.type = e_model::MODEL_9B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 28: type = LLM_TYPE_6B; break; + case 40: type = LLM_TYPE_9B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_BITNET: @@ -972,13 +1103,13 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 26: model.type = e_model::MODEL_3B; break; - default: model.type = 
e_model::MODEL_UNKNOWN; + case 26: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_T5: { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); uint32_t dec_start_token_id; @@ -987,32 +1118,32 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { } switch (hparams.n_layer) { - case 6: model.type = e_model::MODEL_60M; break; // t5-small - case 8: model.type = e_model::MODEL_80M; break; // flan-t5-small + case 6: type = LLM_TYPE_60M; break; // t5-small + case 8: type = LLM_TYPE_80M; break; // flan-t5-small case 12: switch (hparams.n_ff()) { - case 3072: model.type = e_model::MODEL_220M; break; // t5-base - case 2048: model.type = e_model::MODEL_250M; break; // flan-t5-base - default: model.type = e_model::MODEL_UNKNOWN; + case 3072: type = LLM_TYPE_220M; break; // t5-base + case 2048: type = LLM_TYPE_250M; break; // flan-t5-base + default: type = LLM_TYPE_UNKNOWN; } break; case 24: switch (hparams.n_ff()) { - case 4096: model.type = e_model::MODEL_770M; break; // t5-large - case 2816: model.type = e_model::MODEL_780M; break; // flan-t5-large - case 16384: model.type = e_model::MODEL_3B; break; // t5-3b - case 5120: model.type = e_model::MODEL_3B; break; // flan-t5-xl - case 65536: model.type = e_model::MODEL_11B; break; // t5-11b - case 10240: model.type = e_model::MODEL_11B; break; // flan-t5-xxl - default: model.type = e_model::MODEL_UNKNOWN; + case 4096: type = LLM_TYPE_770M; break; // t5-large + case 2816: type = LLM_TYPE_780M; break; // flan-t5-large + case 16384: type = LLM_TYPE_3B; break; // t5-3b + case 5120: type = LLM_TYPE_3B; break; // flan-t5-xl + case 65536: type = LLM_TYPE_11B; break; // t5-11b + case 10240: type = LLM_TYPE_11B; break; // flan-t5-xxl + default: type = LLM_TYPE_UNKNOWN; } break; - default: model.type = e_model::MODEL_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_T5ENCODER: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); - model.type = e_model::MODEL_UNKNOWN; + type = LLM_TYPE_UNKNOWN; } break; case LLM_ARCH_JAIS: { @@ -1020,18 +1151,18 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1_3B; break; - case 40: model.type = e_model::MODEL_13B; break; + case 24: type = LLM_TYPE_1_3B; break; + case 40: type = LLM_TYPE_13B; break; /* TODO: add variants */ - default: model.type = e_model::MODEL_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_NEMOTRON: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_4B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_EXAONE: @@ -1039,44 +1170,48 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_8B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_RWKV6: + case 
LLM_ARCH_RWKV6QWEN2: { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); - ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); - ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); - ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); + ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); + ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); + ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); switch (hparams.n_layer) { - case 24: model.type = e_model::MODEL_1_6B; break; + case 24: type = LLM_TYPE_1_6B; break; case 32: switch (hparams.n_embd) { - case 2560: model.type = e_model::MODEL_3B; break; - case 4096: model.type = e_model::MODEL_7B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; } break; - case 61: model.type = e_model::MODEL_14B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 61: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); - ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); - ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_3B; break; - case 40: model.type = e_model::MODEL_3B; break; + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_3B; break; // Add additional layer/vocab/etc checks here for other model sizes - default: model.type = e_model::MODEL_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_CHAMELEON: @@ -1086,9 +1221,9 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm); switch (hparams.n_layer) { - case 32: model.type = e_model::MODEL_7B; break; - case 48: model.type = e_model::MODEL_34B; break; - default: model.type = e_model::MODEL_UNKNOWN; + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_34B; break; + default: type = LLM_TYPE_UNKNOWN; } } break; case LLM_ARCH_WAVTOKENIZER_DEC: @@ -1101,753 +1236,2309 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) { default: throw std::runtime_error("unsupported model architecture"); } - model.ftype = ml.ftype; + pimpl->n_bytes = ml.n_bytes; + + pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name(); if (hparams.f_max_alibi_bias > 0.0f) { hparams.use_alibi = true; } - hparams.rope_type = llama_rope_type(&model); + hparams.rope_type = llama_model_rope_type(this); } -void llm_load_vocab(llama_model_loader & ml, 
llama_model & model) { - auto & vocab = model.vocab; +void llama_model::load_vocab(llama_model_loader & ml) { + const auto kv = LLM_KV(arch); - struct gguf_context * ctx = ml.meta.get(); - - const auto kv = LLM_KV(model.arch); - - // determine vocab type - { - std::string tokenizer_model; - std::string tokenizer_pre; - - ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model); - ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false); - - if (tokenizer_model == "no_vocab" || tokenizer_model == "none") { - vocab.type = LLAMA_VOCAB_TYPE_NONE; - - // default special tokens - vocab.special_bos_id = LLAMA_TOKEN_NULL; - vocab.special_eos_id = LLAMA_TOKEN_NULL; - vocab.special_unk_id = LLAMA_TOKEN_NULL; - vocab.special_sep_id = LLAMA_TOKEN_NULL; - vocab.special_pad_id = LLAMA_TOKEN_NULL; - vocab.special_cls_id = LLAMA_TOKEN_NULL; - vocab.special_mask_id = LLAMA_TOKEN_NULL; - vocab.linefeed_id = LLAMA_TOKEN_NULL; - - // read vocab size from metadata - if (!ml.get_key(LLM_KV_VOCAB_SIZE, vocab.n_vocab, false)) { - vocab.n_vocab = 0; - LLAMA_LOG_WARN("%s: there is no vocab_size in metadata, vocab.n_vocab will be set to %u\n", __func__, vocab.n_vocab); - } - return; - } - - if (tokenizer_model == "llama") { - vocab.type = LLAMA_VOCAB_TYPE_SPM; - - // default special tokens - vocab.special_bos_id = 1; - vocab.special_eos_id = 2; - vocab.special_unk_id = 0; - vocab.special_sep_id = LLAMA_TOKEN_NULL; - vocab.special_pad_id = LLAMA_TOKEN_NULL; - vocab.special_cls_id = LLAMA_TOKEN_NULL; - vocab.special_mask_id = LLAMA_TOKEN_NULL; - } else if (tokenizer_model == "bert") { - vocab.type = LLAMA_VOCAB_TYPE_WPM; - - // default special tokens - vocab.special_bos_id = LLAMA_TOKEN_NULL; - vocab.special_eos_id = LLAMA_TOKEN_NULL; - vocab.special_unk_id = 100; - vocab.special_sep_id = 102; - vocab.special_pad_id = 0; - vocab.special_cls_id = 101; - vocab.special_mask_id = 103; - } else if (tokenizer_model == "gpt2") { - vocab.type = LLAMA_VOCAB_TYPE_BPE; - - // read bpe merges and populate bpe ranks - const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); - if (merges_keyidx == -1) { - throw std::runtime_error("cannot find tokenizer merges in model file\n"); - } - - const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); - for (int i = 0; i < n_merges; i++) { - const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); - GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); - - std::string first; - std::string second; - - const size_t pos = word.find(' ', 1); - - if (pos != std::string::npos) { - first = word.substr(0, pos); - second = word.substr(pos + 1); - } + vocab.load(ml, kv); +} - vocab.bpe_ranks.emplace(std::make_pair(first, second), i); - } +bool llama_model::load_tensors(llama_model_loader & ml) { + const auto & split_mode = params.split_mode; + const auto & n_gpu_layers = params.n_gpu_layers; + const auto & use_mlock = params.use_mlock; + const auto & tensor_split = params.tensor_split; - // default special tokens - vocab.special_bos_id = 11; - vocab.special_eos_id = 11; - vocab.special_unk_id = LLAMA_TOKEN_NULL; - vocab.special_sep_id = LLAMA_TOKEN_NULL; - vocab.special_pad_id = LLAMA_TOKEN_NULL; - vocab.special_cls_id = LLAMA_TOKEN_NULL; - vocab.special_mask_id = LLAMA_TOKEN_NULL; - } else if (tokenizer_model == "t5") { - vocab.type = LLAMA_VOCAB_TYPE_UGM; - - // default special tokens - vocab.special_bos_id = LLAMA_TOKEN_NULL; - vocab.special_eos_id = 1; - vocab.special_unk_id = 2; - vocab.special_sep_id = LLAMA_TOKEN_NULL; - vocab.special_pad_id = 0; - 
vocab.special_cls_id = LLAMA_TOKEN_NULL; - vocab.special_mask_id = LLAMA_TOKEN_NULL; - - const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str()); - if (precompiled_charsmap_keyidx != -1) { - size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx); - const char * precompiled_charsmap = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); - vocab.precompiled_charsmap.assign(precompiled_charsmap, precompiled_charsmap + n_precompiled_charsmap); -#ifdef IS_BIG_ENDIAN - // correct endiannes of data in precompiled_charsmap binary blob - uint32_t * xcda_blob_size = (uint32_t *) &vocab.precompiled_charsmap[0]; - *xcda_blob_size = __builtin_bswap32(*xcda_blob_size); - assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap); - size_t xcda_array_size = *xcda_blob_size / sizeof(uint32_t); - uint32_t * xcda_array = (uint32_t *) &vocab.precompiled_charsmap[sizeof(uint32_t)]; - for (size_t i = 0; i < xcda_array_size; ++i) { - xcda_array[i] = __builtin_bswap32(xcda_array[i]); - } -#endif - } - } else if (tokenizer_model == "rwkv") { - vocab.type = LLAMA_VOCAB_TYPE_RWKV; - - // default special tokens - vocab.special_bos_id = LLAMA_TOKEN_NULL; - vocab.special_eos_id = LLAMA_TOKEN_NULL; - vocab.special_unk_id = LLAMA_TOKEN_NULL; - vocab.special_sep_id = LLAMA_TOKEN_NULL; - vocab.special_pad_id = LLAMA_TOKEN_NULL; - } else { - throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); - } + const int n_layer = hparams.n_layer; - // for now, only BPE models have pre-tokenizers - if (vocab.type == LLAMA_VOCAB_TYPE_BPE) { - vocab.tokenizer_add_space_prefix = false; - vocab.tokenizer_clean_spaces = true; - if (tokenizer_pre.empty()) { - LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__); - LLAMA_LOG_WARN("%s: \n", __func__); - LLAMA_LOG_WARN("%s: ************************************ \n", __func__); - LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED! 
\n", __func__); - LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL \n", __func__); - LLAMA_LOG_WARN("%s: ************************************ \n", __func__); - LLAMA_LOG_WARN("%s: \n", __func__); - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - } else if (tokenizer_pre == "default") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - } else if ( - tokenizer_pre == "llama3" || - tokenizer_pre == "llama-v3" || - tokenizer_pre == "llama-bpe"|| - tokenizer_pre == "falcon3") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3; - vocab.tokenizer_ignore_merges = true; - vocab.tokenizer_add_bos = true; - } else if ( - tokenizer_pre == "deepseek-llm") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "deepseek-coder") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "deepseek-v3") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "falcon") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON; - } else if ( - tokenizer_pre == "mpt") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_MPT; - } else if ( - tokenizer_pre == "starcoder") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STARCODER; - } else if ( - tokenizer_pre == "gpt-2" || - tokenizer_pre == "phi-2" || - tokenizer_pre == "jina-es" || - tokenizer_pre == "jina-de" || - tokenizer_pre == "gigachat" || - tokenizer_pre == "jina-v1-en" || - tokenizer_pre == "jina-v2-es" || - tokenizer_pre == "jina-v2-de" || - tokenizer_pre == "jina-v2-code" || - tokenizer_pre == "roberta-bpe") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT2; - } else if ( - tokenizer_pre == "refact") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_REFACT; - } else if ( - tokenizer_pre == "command-r") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_COMMAND_R; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "qwen2") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "stablelm2") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STABLELM2; - } else if ( - tokenizer_pre == "olmo") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_OLMO; - } else if ( - tokenizer_pre == "dbrx") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DBRX; - } else if ( - tokenizer_pre == "smaug-bpe") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG; - } else if ( - tokenizer_pre == "poro-chat") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "chatglm-bpe") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4; - vocab.special_bos_id = LLAMA_TOKEN_NULL; - } else if ( - tokenizer_pre == "viking") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "jais") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS; - } else if ( - tokenizer_pre == "tekken") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_TEKKEN; - vocab.tokenizer_clean_spaces = false; - vocab.tokenizer_ignore_merges = true; - vocab.tokenizer_add_bos = true; - } else if ( - tokenizer_pre == "smollm") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMOLLM; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "codeshell") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL; - } else if ( - tokenizer_pre == "bloom") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_BLOOM; - } else if ( - tokenizer_pre == "gpt3-finnish") { - vocab.type_pre = 
LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH; - } else if ( - tokenizer_pre == "exaone") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_EXAONE; - } else if ( - tokenizer_pre == "chameleon") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHAMELEON; - vocab.tokenizer_add_bos = true; - vocab.tokenizer_clean_spaces = false; - } else if ( - tokenizer_pre == "minerva-7b") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_MINERVA; - } else if ( - tokenizer_pre == "megrez") { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2; - } else { - throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); - } - } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - vocab.tokenizer_add_space_prefix = true; - vocab.tokenizer_clean_spaces = false; - vocab.tokenizer_add_bos = true; - vocab.tokenizer_add_eos = false; - } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - vocab.tokenizer_add_space_prefix = false; - vocab.tokenizer_clean_spaces = true; - vocab.tokenizer_add_bos = true; - vocab.tokenizer_add_eos = false; - } else if (vocab.type == LLAMA_VOCAB_TYPE_UGM) { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - vocab.tokenizer_add_bos = false; - vocab.tokenizer_add_eos = true; - } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - vocab.tokenizer_add_space_prefix = false; - vocab.tokenizer_clean_spaces = false; - vocab.tokenizer_add_bos = false; - vocab.tokenizer_add_eos = false; - } else { - vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; - } + const bool use_mmap_buffer = true; - ml.get_key(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.tokenizer_add_space_prefix, false); - ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.tokenizer_remove_extra_whitespaces, false); + // build a list of buffer types for the CPU and GPU devices + pimpl->cpu_buft_list = make_cpu_buft_list(devices); + for (auto * dev : devices) { + buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); + // add CPU buffer types as a fallback + buft_list.insert(buft_list.end(), pimpl->cpu_buft_list.begin(), pimpl->cpu_buft_list.end()); + pimpl->gpu_buft_list.emplace(dev, std::move(buft_list)); } - const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str()); - if (token_idx == -1) { - throw std::runtime_error("cannot find tokenizer vocab in model file\n"); + // calculate the split points + bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; }); + std::vector splits(n_devices()); + if (all_zero) { + // default split, by free memory + for (size_t i = 0; i < n_devices(); ++i) { + ggml_backend_dev_t dev = devices[i]; + size_t total; + size_t free; + ggml_backend_dev_memory(dev, &free, &total); + splits[i] = free; + } + } else { + std::copy(tensor_split, tensor_split + n_devices(), splits.begin()); } - const float * scores = nullptr; - const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); - if (score_idx != -1) { - scores = (const float * ) gguf_get_arr_data(ctx, score_idx); + // sum and normalize the splits to get the split points + float split_sum = 0.0f; + for (size_t i = 0; i < n_devices(); ++i) { + split_sum += splits[i]; + splits[i] = split_sum; } - - const int * toktypes = nullptr; - const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); - if (toktype_idx != -1) { - toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); + for (size_t i = 
0; i < n_devices(); ++i) { + splits[i] /= split_sum; } - const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx); + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0); + const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1); + auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { + if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) { + return {cpu_dev, &pimpl->cpu_buft_list}; + } + const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin(); + auto * dev = devices.at(layer_gpu); + return {dev, &pimpl->gpu_buft_list.at(dev)}; + }; - vocab.n_vocab = n_vocab; - vocab.id_to_token.resize(n_vocab); + // assign the input layer + // there is very little benefit to offloading the input layer, so always keep it on the CPU + pimpl->dev_input = { cpu_dev, &pimpl->cpu_buft_list }; - for (uint32_t i = 0; i < n_vocab; i++) { - std::string word = gguf_get_arr_str(ctx, token_idx, i); - if (word.empty()) { - LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i); - word = "[EMPTY_" + std::to_string(i) + "]"; - } + // assign the repeating layers to the devices according to the splits + pimpl->dev_layer.resize(n_layer); + for (int il = 0; il < n_layer; ++il) { + pimpl->dev_layer[il] = get_layer_buft_list(il); + } - vocab.token_to_id[word] = i; - vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size()); - - auto & token_data = vocab.id_to_token[i]; - token_data.text = std::move(word); - token_data.score = scores ? scores[i] : 0.0f; - token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; - - if (toktypes) { //TODO: remove, required until per token attributes are available from GGUF file - switch(toktypes[i]) { - case LLAMA_TOKEN_TYPE_UNKNOWN: token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN; break; - case LLAMA_TOKEN_TYPE_UNUSED: token_data.attr = LLAMA_TOKEN_ATTR_UNUSED; break; - case LLAMA_TOKEN_TYPE_NORMAL: token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; break; - case LLAMA_TOKEN_TYPE_CONTROL: token_data.attr = LLAMA_TOKEN_ATTR_CONTROL; break; - case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break; - case LLAMA_TOKEN_TYPE_BYTE: token_data.attr = LLAMA_TOKEN_ATTR_BYTE; break; - case LLAMA_TOKEN_TYPE_UNDEFINED: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; - default: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break; + // assign the output layer + pimpl->dev_output = get_layer_buft_list(n_layer); + + // one ggml context per buffer type + int max_n_tensors = ml.n_tensors; + max_n_tensors += 1; // duplicated output tensor + max_n_tensors += n_layer*2; // duplicated rope freq tensors + const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors; + + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ ctx_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + throw std::runtime_error(format("failed to create ggml context")); } - } - } - GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size()); - vocab.init_tokenizer(); + ctx_map[buft] = ctx; + pimpl->ctxs.emplace_back(ctx); - // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' - if (vocab.type == 
LLAMA_VOCAB_TYPE_SPM) {
-        try {
-            vocab.linefeed_id = llama_byte_to_token_impl(vocab, '\n');
-        } catch (const std::exception & e) {
-            LLAMA_LOG_WARN("%s: SPM vocabulary, but newline token not found: %s! Using special_pad_id instead.", __func__, e.what());
-            vocab.linefeed_id = vocab.special_pad_id;
+            return ctx;
         }
-    } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
-        vocab.linefeed_id = vocab.special_pad_id;
-    } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) {
-        const std::vector<int> ids = llama_tokenize_internal(vocab, "\n", false);
-        GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
-        vocab.linefeed_id = ids[0];
-    } else {
-        const std::vector<int> ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A
+        return it->second;
+    };
-        //GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
-        if (ids.empty()) {
-            LLAMA_LOG_WARN("%s: model vocab missing newline token, using special_pad_id instead\n", __func__);
-            vocab.linefeed_id = vocab.special_pad_id;
-        } else {
-            vocab.linefeed_id = ids[0];
-        }
-    }
+    const auto TENSOR_DUPLICATED   = llama_model_loader::TENSOR_DUPLICATED;
+    const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED;
-    // special tokens
+    // create tensors for the weights
     {
-        const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
-            { LLM_KV_TOKENIZER_BOS_ID, vocab.special_bos_id },
-            { LLM_KV_TOKENIZER_EOS_ID, vocab.special_eos_id },
-            { LLM_KV_TOKENIZER_EOT_ID, vocab.special_eot_id },
-            { LLM_KV_TOKENIZER_EOM_ID, vocab.special_eom_id },
-            { LLM_KV_TOKENIZER_UNK_ID, vocab.special_unk_id },
-            { LLM_KV_TOKENIZER_SEP_ID, vocab.special_sep_id },
-            { LLM_KV_TOKENIZER_PAD_ID, vocab.special_pad_id },
-            { LLM_KV_TOKENIZER_CLS_ID, vocab.special_cls_id },
-            { LLM_KV_TOKENIZER_MASK_ID, vocab.special_mask_id },
-            { LLM_KV_TOKENIZER_FIM_PRE_ID, vocab.special_fim_pre_id },
-            { LLM_KV_TOKENIZER_FIM_SUF_ID, vocab.special_fim_suf_id },
-            { LLM_KV_TOKENIZER_FIM_MID_ID, vocab.special_fim_mid_id },
-            { LLM_KV_TOKENIZER_FIM_PAD_ID, vocab.special_fim_pad_id },
-            { LLM_KV_TOKENIZER_FIM_REP_ID, vocab.special_fim_rep_id },
-            { LLM_KV_TOKENIZER_FIM_SEP_ID, vocab.special_fim_sep_id },
-
-            // deprecated
-            { LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_fim_pre_id },
-            { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_fim_suf_id },
-            { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_fim_mid_id },
-        };
+        // note: cast to int64_t since we will use these for the tensor dimensions
+        const int64_t n_head = hparams.n_head();
+        const int64_t n_head_kv = hparams.n_head_kv();
+        const int64_t n_embd = hparams.n_embd;
+        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+        const int64_t n_embd_head_k = hparams.n_embd_head_k;
+        const int64_t n_embd_head_v = hparams.n_embd_head_v;
+        const int64_t n_ff = hparams.n_ff();
+        const int64_t n_embd_gqa = n_embd_v_gqa;
+        const int64_t n_vocab = vocab.n_tokens();
+        const int64_t n_token_types = vocab.n_token_types();
+        const int64_t n_rot = hparams.n_rot;
+        const int64_t n_expert = hparams.n_expert;
+        const int64_t n_expert_used = hparams.n_expert_used;
+        const int64_t n_ctx_train = hparams.n_ctx_train;
+
+        if (n_expert > 0 && hparams.n_expert_used == 0) {
+            throw std::runtime_error("model has expert layers but no expert layers are used");
+        }
-        for (const auto & it : special_token_types) {
-            const std::string & key = kv(std::get<0>(it));
-            int32_t & id = std::get<1>(it);
+        int n_moved_tensors = 0;
+        ggml_tensor * first_moved_tensor = nullptr;
+        ggml_backend_buffer_type_t first_moved_from_buft = nullptr;
+        ggml_backend_buffer_type_t first_moved_to_buft   = nullptr;
-            uint32_t new_id;
-            if (!ml.get_key(std::get<0>(it), new_id, false)) {
-                continue;
-            }
-            if (new_id >= vocab.id_to_token.size()) {
-                LLAMA_LOG_WARN("%s: bad special token: '%s' = %ud, using default id %d\n",
-                        __func__, key.c_str(), new_id, id);
-            } else {
-                id = new_id;
+        auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags) -> ggml_tensor * {
+            ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str());
+
+            if (!t_meta) {
+                if (flags & TENSOR_NOT_REQUIRED) {
+                    return nullptr;
+                }
+                throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str()));
             }
-        }
-        // Handle add_bos_token and add_eos_token
-        {
-            bool temp = true;
+            // some models use the token embedding tensor as the output, but since these are used in different layers and with different ops
+            // the tensor is duplicated
+            // to handle this, we check if the tensor is duplicated, and if so, we assume that it is being loaded as the output tensor
+            llm_tensor tn_tensor = tn.tensor;
+            if (tn.tensor == LLM_TENSOR_TOKEN_EMBD && flags & TENSOR_DUPLICATED) {
+                tn_tensor = LLM_TENSOR_OUTPUT;
+            }
-            if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
-                vocab.tokenizer_add_bos = temp;
+            llm_tensor_info info;
+            try {
+                info = llm_tensor_info_for(tn_tensor);
+            } catch (const std::out_of_range & e) {
+                throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str()));
             }
-            if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
-                vocab.tokenizer_add_eos = temp;
+
+            // tensors with "bias" suffix are always used with GGML_OP_ADD
+            ggml_op op;
+            bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0;
+            if (bias) {
+                op = GGML_OP_ADD;
+            } else {
+                op = info.op;
             }
-        }
-        // auto-detect special tokens by text
-        // TODO: convert scripts should provide these tokens through the KV metadata LLM_KV_TOKENIZER_...
-        // for now, we apply this workaround to find the tokens based on their text
-
-        for (const auto & t : vocab.token_to_id) {
-            // find EOT token: "<|eot_id|>", "<|im_end|>", "<end_of_turn>", etc.
-            if (vocab.special_eot_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|eot_id|>"
-                        || t.first == "<|im_end|>"
-                        || t.first == "<|end|>"
-                        || t.first == "<end_of_turn>"
-                        || t.first == "<|endoftext|>"
-                        || t.first == "<EOT>"
-                        || t.first == "<|end▁of▁sentence|>" // DeepSeek
-                        ) {
-                    vocab.special_eot_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
+            // sanity checks
+            if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) {
+                if (tn.bid != -1) {
+                    GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str());
+                }
+            } else {
+                if (tn.bid == -1) {
+                    GGML_ABORT("repeating layer tensor %s used without a layer number", tn.str().c_str());
                 }
             }
-            // find EOM token: "<|eom_id|>"
-            if (vocab.special_eom_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|eom_id|>"
-                        ) {
-                    vocab.special_eom_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
+            // select the buffer type for this tensor
+            buft_list_t * buft_list;
+            switch (info.layer) {
+                case LLM_TENSOR_LAYER_INPUT:
+                    buft_list = pimpl->dev_input.buft_list;
+                    break;
+                case LLM_TENSOR_LAYER_OUTPUT:
+                    buft_list = pimpl->dev_output.buft_list;
+                    break;
+                case LLM_TENSOR_LAYER_REPEATING:
+                    buft_list = pimpl->dev_layer.at(tn.bid).buft_list;
+                    break;
+                default:
+                    GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str());
             }
-            // find FIM_PRE token: "<|fim_prefix|>", "<fim-prefix>", "<PRE>", etc.
-            if (vocab.special_fim_pre_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|fim_prefix|>"  // Qwen
-                        || t.first == "<fim-prefix>"
-                        || t.first == "<|fim▁begin|>" // DeepSeek
-                        || t.first == "<PRE>"
-                        ) {
-                    vocab.special_fim_pre_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
+            ggml_backend_buffer_type_t buft = select_weight_buft(hparams, t_meta, op, *buft_list);
+            if (!buft) {
+                throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
             }
 
-            // find FIM_SUF token: "<|fim_suffix|>", "<fim-suffix>", "<SUF>", etc.
-            if (vocab.special_fim_suf_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|fim_suffix|>" // Qwen
-                        || t.first == "<fim-suffix>"
-                        || t.first == "<|fim▁hole|>" // DeepSeek
-                        || t.first == "<SUF>"
-                        ) {
-                    vocab.special_fim_suf_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
-                }
+            // avoid using a host buffer when using mmap
+            auto * buft_dev = ggml_backend_buft_get_device(buft);
+            if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
+                auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+                buft = ggml_backend_dev_buffer_type(cpu_dev);
             }
 
-            // find FIM_MID token: "<|fim_middle|>", "<fim-middle>", "<MID>", etc.
-            if (vocab.special_fim_mid_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|fim_middle|>" // Qwen
-                        || t.first == "<fim-middle>"
-                        || t.first == "<|fim▁end|>"  // DeepSeek
-                        || t.first == "<MID>"
-                        ) {
-                    vocab.special_fim_mid_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
+            if (buft != buft_list->front().second) {
+                n_moved_tensors++;
+                if (!first_moved_tensor) {
+                    first_moved_tensor = t_meta;
+                    first_moved_from_buft = buft_list->front().second;
+                    first_moved_to_buft   = buft;
                 }
             }
 
-            // find FIM_PAD token: "<|fim_pad|>", "<fim-pad>", "<PAD>", etc.
-            if (vocab.special_fim_pad_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|fim_pad|>" // Qwen
-                        || t.first == "<fim-pad>"
-                        || t.first == "<PAD>"
-                        ) {
-                    vocab.special_fim_pad_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                    }
+            ggml_context * ctx = ctx_for_buft(buft);
+
+            // if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one
+            if (flags & TENSOR_DUPLICATED) {
+                ggml_tensor * t = ggml_get_tensor(ctx, tn.str().c_str());
+                if (t) {
+                    return t;
                 }
             }
+            return ml.create_tensor(ctx, tn, ne, flags);
+        };
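A minimal sketch of the calling convention the new create_tensor lambda establishes for the per-architecture code below (illustrative only: il stands in for the per-layer loop index i used at the real call sites, and the three flag combinations shown are the ones that appear there):

    // required: resolved and placed on a buffer type, or a std::runtime_error is thrown
    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", il), {n_embd, n_embd}, 0);

    // optional: returns nullptr instead of throwing when the tensor is absent from the GGUF
    layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", il), {n_embd}, TENSOR_NOT_REQUIRED);

    // optional + duplicated: a tensor already created under the same name in the chosen
    // context is reused; otherwise the model loader creates a fresh copy
    layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", il), {n_rot/2},
                                     TENSOR_NOT_REQUIRED | (il != 0 ? TENSOR_DUPLICATED : 0));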
 
-            // find FIM_REP token: "<|fim_repo|>", "<fim-repo>", "<REPO>", etc.
-            if (vocab.special_fim_rep_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|fim_repo|>"  // Qwen
-                        || t.first == "<|repo_name|>"
-                        || t.first == "<fim-repo>"
-                        || t.first == "<REPO>"
-                        ) {
-                    vocab.special_fim_rep_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+        layers.resize(n_layer);
+
+        // TODO: move to a separate function
+        const auto tn = LLM_TN(arch);
+        switch (arch) {
+            case LLM_ARCH_LLAMA:
+            case LLM_ARCH_REFACT:
+            case LLM_ARCH_MINICPM:
+            case LLM_ARCH_GRANITE:
+            case LLM_ARCH_GRANITE_MOE:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                     }
-                }
-            }
 
-            // find FIM_SEP token: "<|file_sep|>"
-            if (vocab.special_fim_sep_id == LLAMA_TOKEN_NULL) {
-                if (false
-                        || t.first == "<|file_sep|>" // Qwen
-                        ) {
-                    vocab.special_fim_sep_id = t.second;
-                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.second, t.first.c_str());
-                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        // optional bias tensors
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
+                            layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                            layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                        }
+                        else {
+                            layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                        }
+
+                        if (n_expert == 0) {
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+
+                            // optional MLP bias
+                            layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
+                            layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                            layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
+                        } else {
+                            layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff, n_expert}, TENSOR_NOT_REQUIRED);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
+                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
+                        }
+                    }
+                } break;
+            case LLM_ARCH_DECI:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                     }
-                }
-            }
-        }
 
-        // maintain a list of tokens that cause end-of-generation
-        // this is currently determined based on the token text, which is obviously not ideal
-        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
-        vocab.special_eog_ids.clear();
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+                        const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa(i);
+                        const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa(i);
+                        const int64_t n_embd_gqa    = hparams.n_embd_v_gqa(i);
+                        const int64_t n_ff          = hparams.n_ff(i);
+                        const int64_t n_head        = hparams.n_head(i);
+                        const int64_t n_head_kv     = hparams.n_head_kv(i);
+
+                        if (n_head_kv == 0 && n_head > 0) {
+                            // linear attention for DeciLMCausalModel
+                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        }
+                        else if (n_head_kv > 0) {
+                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+                        }
+
+                        // optional bias tensors
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
+                            layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                            layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                        }
+                        else {
+                            layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                        }
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+
+                        // optional MLP bias
+                        layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
+                    }
+                } break;
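Unlike the LLaMA case above, the DECI loop re-reads every GQA-dependent size through hparams per layer, because DeciLM models may vary head counts from layer to layer. A hypothetical two-layer illustration of the branch taken above (numbers invented for clarity, assuming n_embd_head_k = 128):

    // layer 0: hparams.n_head_kv(0) == 0 -> linear attention: only attn_norm and wo are created
    // layer 1: hparams.n_head_kv(1) == 8 -> full attention:
    //          n_embd_k_gqa(1) = n_embd_head_k * n_head_kv(1) = 128 * 8 = 1024,
    //          so wk is created with shape {n_embd, 1024} and wv with {n_embd, n_embd_v_gqa(1)}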
+            case LLM_ARCH_MINICPM3:
+                {
+                    const int64_t n_embd_head_qk_rope = hparams.n_rot;
+                    const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+
+                    const int64_t q_lora_rank  = hparams.n_lora_q;
+                    const int64_t kv_lora_rank = hparams.n_lora_kv;
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
 
-        if (vocab.special_fim_pad_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_pad_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_fim_pad_id);
-        }
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
 
-        if (vocab.special_fim_rep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_rep_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_fim_rep_id);
-        }
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0);
 
-        if (vocab.special_fim_sep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_sep_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_fim_sep_id);
-        }
+                        layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
 
-        for (const auto & t : vocab.token_to_id) {
-            if (false
-                    || t.first == "<|eot_id|>"
-                    || t.first == "<|im_end|>"
-                    || t.first == "<|end|>"
-                    || t.first == "<end_of_turn>"
-                    || t.first == "<|endoftext|>"
-                    || t.first == "<|eom_id|>"
-                    || t.first == "<EOT>"
-               ) {
-                vocab.special_eog_ids.insert(t.second);
-                if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                            __func__, t.second, t.first.c_str());
-                    vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
-                }
-            } else {
-                // token is control, but not marked as EOG -> print a debug log
-                if (vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && vocab.special_eog_ids.count(t.second) == 0) {
-                    LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
-                            __func__, t.second, t.first.c_str());
-                }
-            }
-        }
+                        layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
+                        layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0);
 
-        // sanity checks
-        if (vocab.special_eos_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eos_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_eos_id);
-            LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
-        }
+                        layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
+                        layer.wkv_b     = create_tensor(tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
+                        layer.wo        = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {              n_head * (                      n_embd_head_v), n_embd}, 0);
 
-        if (vocab.special_eot_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eot_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_eot_id);
-            LLAMA_LOG_WARN("%s: special_eot_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
-        }
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
 
-        if (vocab.special_eom_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eom_id) == 0) {
-            vocab.special_eog_ids.insert(vocab.special_eom_id);
-            LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
-        }
-    }
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
 
-    // build special tokens cache
-    {
-        for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
-            if (vocab.id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
-                vocab.cache_special_tokens.push_back(id);
-            }
-        }
+                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                    }
+                } break;
+            case LLM_ARCH_GROK:
+                {
+                    if (n_expert == 0) {
+                        throw std::runtime_error("Grok model cannot have zero experts");
+                    }
 
-        std::sort(vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
-            [&] (const llama_vocab::id a, const llama_vocab::id b) {
-                return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
-            }
-        );
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
-        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t)vocab.cache_special_tokens.size());
-    }
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
 
-    // build token to piece cache
-    {
-        size_t size_cache = 0;
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
 
-        std::vector<std::string> cache_token_to_piece(n_vocab);
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
 
-        for (uint32_t id = 0; id < n_vocab; ++id) {
-            cache_token_to_piece[id] = llama_token_to_piece(&model, id, true);
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-            size_cache += cache_token_to_piece[id].size();
-        }
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
-        std::swap(vocab.cache_token_to_piece, cache_token_to_piece);
+                        layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
 
-        LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
-    }
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
 
-    // Handle per token attributes
-    //NOTE: Each model customizes per token attributes.
-    //NOTE: Per token attributes are missing from the GGUF file.
-    //TODO: Extract attributes from GGUF file.
-    {
-        auto _contains_any = [] (const std::string &str, const std::vector<std::string> &substrs) -> bool {
-            for (auto substr : substrs) {
-                if (str.find(substr) < std::string::npos) {
-                    return true;
-                }
-            }
-            return false;
-        };
+                        layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
 
-        auto _set_tokenid_attr = [&] (const llama_vocab::id id, llama_token_attr attr, bool value) {
-            uint32_t current = vocab.id_to_token.at(id).attr;
-            current = value ? (current | attr) : (current & ~attr);
-            vocab.id_to_token[id].attr = (llama_token_attr) current;
-        };
+                        layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
+                    }
+                } break;
+            case LLM_ARCH_DBRX:
+                {
+                    if (n_expert == 0) {
+                        throw std::runtime_error("DBRX model cannot have zero experts");
+                    }
 
-        auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
-            _set_tokenid_attr(vocab.token_to_id.at(token), attr, value);
-        };
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
-        std::string model_name;
-        std::string tokenizer_pre;
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
 
-        ml.get_key(LLM_KV_GENERAL_NAME, model_name, false);
-        ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
 
-        // model name to lowercase
-        std::transform(model_name.begin(), model_name.end(), model_name.begin(),
-            [] (const std::string::value_type x) {
-                return std::tolower(x);
-            }
-        );
-
-        // set attributes by model/tokenizer name
-        if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
-            _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
-        } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
-            for (auto id : vocab.cache_special_tokens) {
-                _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
-            }
-            for (auto token : {"</s>"}) {
-                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
-            }
-            for (auto token : {"<unk>", "<s>", "<|endoftext|>"}) {
-                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
-            }
-        }
-    }
-}
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
-    const auto & hparams = model.hparams;
-    const auto & vocab   = model.vocab;
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
 
-    const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
+                        layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
 
-    auto print_f = [](const std::function<uint32_t(uint32_t)> & f, uint32_t n) {
-        bool is_var = false;
+                        layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff,   n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff,   n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff,   n_expert}, 0);
+                    }
+                } break;
+            case LLM_ARCH_BAICHUAN:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+                    {
+                        output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                        output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+                    }
 
-        std::vector<uint32_t> v;
-        for (uint32_t i = 0; i < n; ++i) {
-            v.push_back(f(i));
-            if (v[i] != v[0]) {
-                is_var = true;
-            }
-        }
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
 
-        std::stringstream ss;
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
-        if (is_var) {
-            ss << "[";
-            for (uint32_t i = 0; i < n; ++i) {
-                ss << v[i];
-                if (i < n - 1) {
-                    ss << ", ";
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_FALCON:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    {
+                        output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                        output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+
+                        output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                        if (!output) {
+                            output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU
+                        }
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.attn_norm_2   = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_STARCODER:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+                    pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train}, 0);
+
+                    // output
+                    {
+                        output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                        output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                        output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                        if (!output) {
+                            // needs to be on GPU
+                            output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                        }
+
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
+
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i),   {n_embd, n_ff}, 0);
+                        layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i),     {n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_BERT:
+            case LLM_ARCH_NOMIC_BERT:
+                {
+                    tok_embd     = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0);
+                    type_embd    = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0);
+
+                    if (arch == LLM_ARCH_BERT) {
+                        pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD,    "weight"), {n_embd, n_ctx_train}, 0);
+
+                        cls   = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
+                        cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"),   {n_embd},         TENSOR_NOT_REQUIRED);
+
+                        cls_out   = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED);
+                        cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"),   {1},         TENSOR_NOT_REQUIRED);
+                    }
+
+                    tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        if (arch == LLM_ARCH_BERT) {
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                            layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i),   {n_embd}, 0);
+
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                            layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i),   {n_embd_gqa}, 0);
+
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                            layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i),   {n_embd_gqa}, 0);
+                        } else {
+                            layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        }
+
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,        "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN,      "weight", i), {n_ff, n_embd}, 0);
+
+                        if (arch == LLM_ARCH_BERT) {
+                            layer.bo         = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
+                            layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, 0);
+                            layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
+                        } else {
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        }
+
+                        layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
+                        layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i),   {n_embd}, 0);
+                    }
+                } break;
+            case LLM_ARCH_JINA_BERT_V2:
+                {
+                    tok_embd  = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0); // word_embeddings
+                    type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); // token_type_embeddings
+
+                    tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); // LayerNorm
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0); //LayerNorm bias
+
+                    cls   = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED);
+                    cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"),   {1},         TENSOR_NOT_REQUIRED);
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i]; // JinaBertLayer
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd}, 0);
+
+                        layer.attn_q_norm   = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias",   i), {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias",   i), {n_embd_gqa}, 0);
+
+                        layer.attn_k_norm   = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias",   i), {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias",   i), {n_embd_gqa}, 0);
+
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); //output_dens
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd}, 0); //output_dens
+
+                        layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); //output_norm
+                        layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias",   i), {n_embd}, 0);
+
+                        layer.attn_norm_2   = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias",   i), {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd}, 0);
+
+                        layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
+                        layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias",   i), {n_embd}, 0);
+                    }
+                } break;
+            case LLM_ARCH_BLOOM:
+                {
+                    tok_embd   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab}, 0);
+                    tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0);
+
+                    // output
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias",   i), {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias",   i), {n_embd + 2*n_embd_gqa}, 0);
+
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd}, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias",   i), {n_embd}, 0);
+
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd}, 0);
+
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias",   i), {n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_MPT:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+                    pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train}, TENSOR_NOT_REQUIRED);
+
+                    // output
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, TENSOR_NOT_REQUIRED);
+
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    if (!output) {
+                        output    = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
+
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, TENSOR_NOT_REQUIRED);
+
+                        layer.attn_q_norm   = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias",   i), {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        layer.attn_k_norm   = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias",   i), {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        // AWQ ScaleActivation layer
+                        layer.ffn_act = create_tensor(tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, TENSOR_NOT_REQUIRED);
+                    }
+                } break;
+            case LLM_ARCH_STABLELM:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm =   create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        // optional bias tensors, present in Stable LM 2 1.6B
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+
+                        // optional q and k layernorms, present in StableLM 2 12B
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head},    TENSOR_NOT_REQUIRED);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED);
+
+                        // optional FFN norm, not present in StableLM 2 12B which uses parallel residual
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_QWEN:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd*3}, 0);
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff/2}, 0);
+                    }
+                } break;
+            case LLM_ARCH_QWEN2:
+            case LLM_ARCH_QWEN2VL:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        // optional bias tensors
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd}, 0);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, 0);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_QWEN2MOE:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        // bias tensors
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd}, 0);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, 0);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+
+                        if (n_expert == 0) {
+                            throw std::runtime_error("n_expert must be > 0 for QWEN2MOE");
+                        }
+                        if (n_expert_used == 0) {
+                            throw std::runtime_error("n_expert_used must be > 0 for QWEN2MOE");
+                        }
+
+                        // MoE branch
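+                        // use the per-expert FFN size when provided, otherwise split n_ff across the experts that are used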
+                        const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
+
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+
+                        // Shared expert branch
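+                        // the shared expert falls back to the full FFN size when n_ff_shexp is not set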
+                        const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
+
+                        layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {    n_embd, n_ff_shexp}, 0);
+                        layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp,     n_embd}, 0);
+                        layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {    n_embd, n_ff_shexp}, 0);
+                    }
+                } break;
+            case LLM_ARCH_PHI2:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+                    output_b      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "bias"),   {n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);
+
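+                        // prefer the fused QKV projection; fall back to separate Q/K/V tensors when it is absent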
+                        if (layer.wqkv == nullptr) {
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                            layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd}, 0);
+
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                            layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i),   {n_embd_gqa}, 0);
+
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                            layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i),   {n_embd_gqa}, 0);
+                        }
+
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_PHI3:
+                {
+                    const int64_t n_embd_head = n_embd / n_head;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, TENSOR_NOT_REQUIRED);
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
+
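+                        // the up projection packs the gate and up halves into one tensor, hence the 2*n_ff width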
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);
+
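+                        // long/short rope scaling factors are shared across layers; layers after the first reference them as duplicates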
+                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                    }
+                } break;
+            case LLM_ARCH_PLAMO:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_GPT2:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+                    pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train}, 0);
+
+                    // output
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
+
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_CODESHELL:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
+
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i),   {n_embd, n_ff}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i),     {n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_ORION:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_INTERNLM2:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        // layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_GEMMA:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                    }
+                } break;
+            case LLM_ARCH_GEMMA2:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+                    }
+                } break;
+            case LLM_ARCH_STARCODER2:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        // bias tensors
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd}, 0);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, 0);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+
+                        // bias tensors
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {  n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_MAMBA:
+                {
+                    const int64_t d_conv  = hparams.ssm_d_conv;
+                    const int64_t d_inner = hparams.ssm_d_inner;
+                    const int64_t d_state = hparams.ssm_d_state;
+                    const int64_t dt_rank = hparams.ssm_dt_rank;
+
+                    // only an expansion factor of 2 is supported for now
+                    if (2 * n_embd != d_inner) {
+                        throw std::runtime_error("only an expansion factor of 2 is supported for now");
+                    }
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed, duplicated to allow offloading
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        // norm
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0);
+
+                        layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0);
+                        layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0);
+
+                        layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0);
+
+                        layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0);
+                        layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0);
+
+                        // no "weight" suffix for these
+                        layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0);
+                        layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0);
+
+                        // out_proj
+                        layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
+                    }
+                } break;
+            case LLM_ARCH_XVERSE:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_COMMAND_R:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    // init output from the input tok embed
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
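+                        // only the deeper variants (64+ layers) carry per-head Q/K norm tensors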
+                        if (n_layer >= 64) {
+                            layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0);
+                            layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0);
+                        }
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_COHERE2:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+                    // init output from the input tok embed
+                    output      = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
+                    }
+                } break;
+            case LLM_ARCH_OLMO:  // adapted from LLM_ARCH_LLAMA with norm params removed
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_OLMO2:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+                    }
+                } break;
+            case LLM_ARCH_OLMOE:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+
+                        if (n_expert == 0) {
+                            throw std::runtime_error("n_expert must be > 0");
+                        }
+                        if (n_expert_used == 0) {
+                            throw std::runtime_error("n_expert_used must be > 0");
+                        }
+
+                        // MoE branch
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff,   n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff,   n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff,   n_expert}, 0);
+                    }
+                } break;
+            case LLM_ARCH_OPENELM:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    // init output from the input tok embed
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+
+                    for (int i = 0; i < n_layer; ++i) {
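+                        // OpenELM varies the attention head count and FFN size per layer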
+                        const int64_t n_head      =   hparams.n_head(i);
+                        const int64_t n_head_qkv  = 2*hparams.n_head_kv(i) + n_head;
+                        const int64_t n_ff        =   hparams.n_ff(i);
+
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_GPTNEOX:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
+
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_ARCTIC:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                        layer.ffn_norm_exps = create_tensor(tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff, n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
+                    }
+                } break;
+            case LLM_ARCH_DEEPSEEK:
+                {
+                    const int64_t n_ff_exp        = hparams.n_ff_exp;
+                    const int64_t n_expert_shared = hparams.n_expert_shared;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
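+                        // the first n_layer_dense_lead layers use a dense FFN; the remaining layers use the MoE branch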
+                        if (i < (int) hparams.n_layer_dense_lead) {
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        } else {
+                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+
+                            if (n_expert == 0) {
+                                throw std::runtime_error("n_expert must be > 0");
+                            }
+                            if (n_expert_used == 0) {
+                                throw std::runtime_error("n_expert_used must be > 0");
+                            }
+
+                            // MoE branch
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
+                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+
+                            // Shared expert branch
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {        n_ff_exp * n_expert_shared, n_embd}, 0);
+                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                        }
+                    }
+                } break;
+            case LLM_ARCH_DEEPSEEK2:
+                {
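+                    // the 27-layer "lite" variant skips the low-rank Q projection and uses a direct Q weight instead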
+                    const bool is_lite = (hparams.n_layer == 27);
+
+                    const int64_t n_embd_head_qk_rope = hparams.n_rot;
+                    const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+
+                    const int64_t q_lora_rank  = hparams.n_lora_q;
+                    const int64_t kv_lora_rank = hparams.n_lora_kv;
+
+                    const int64_t n_ff_exp        = hparams.n_ff_exp;
+                    const int64_t n_expert_shared = hparams.n_expert_shared;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        if (!is_lite) {
+                            layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0);
+                        }
+
+                        layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
+
+                        if (!is_lite) {
+                            layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
+                            layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0);
+                        } else {
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        }
+
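+                        // latent attention: K/V are compressed down to kv_lora_rank (plus the rope dims), then expanded per head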
+                        layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
+                        layer.wkv_b     = create_tensor(tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
+                        layer.wo        = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {              n_head * (                      n_embd_head_v), n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        if (i < (int) hparams.n_layer_dense_lead) {
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        } else {
+                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
+
+                            if (n_expert == 0) {
+                                throw std::runtime_error("n_expert must be > 0");
+                            }
+                            if (n_expert_used == 0) {
+                                throw std::runtime_error("n_expert_used must be > 0");
+                            }
+
+                            // MoE branch
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
+                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+
+                            // Shared expert branch
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {        n_ff_exp * n_expert_shared, n_embd}, 0);
+                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                        }
+                    }
+                } break;
+            case LLM_ARCH_BITNET:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm     = create_tensor(tn(LLM_TENSOR_ATTN_NORM,     "weight", i), {n_embd}, 0);
+                        layer.attn_sub_norm = create_tensor(tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}, 0);
+
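+                        // each projection may carry an optional per-tensor scale factor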
+                        layer.wq       = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wq_scale = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+                        layer.wk       = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wk_scale = create_tensor(tn(LLM_TENSOR_ATTN_K,   "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+                        layer.wv       = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv_scale = create_tensor(tn(LLM_TENSOR_ATTN_V,   "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+                        layer.wo       = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.wo_scale = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_norm     = create_tensor(tn(LLM_TENSOR_FFN_NORM,     "weight", i), {n_embd}, 0);
+                        layer.ffn_sub_norm = create_tensor(tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}, 0);
+
+                        layer.ffn_gate       = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_gate_scale = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down       = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_scale = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_up         = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_scale   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "scale",  i), {1}, TENSOR_NOT_REQUIRED);
+                    }
+                } break;
+            case LLM_ARCH_T5:
+                {
+                    const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm     = create_tensor(tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
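+                        // encoder self-attention and FFN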
+                        layer.attn_norm_enc  = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd}, 0);
+                        layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED);
+
+                        layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
+
+                        layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up_enc   = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+
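+                        // decoder self-attention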
+                        layer.attn_norm  = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM,  "weight", i), {n_embd}, 0);
+                        layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_DEC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_DEC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_DEC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
+
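+                        // decoder cross-attention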
+                        layer.attn_norm_cross  = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM,  "weight", i), {n_embd}, 0);
+                        // this tensor seems to be unused in HF transformers implementation
+                        layer.attn_rel_b_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED);
+
+                        layer.wq_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wk_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd,   n_ff}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_DEC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_T5ENCODER:
+                {
+                    const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm_enc  = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd}, 0);
+                        layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED);
+
+                        layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
+
+                        layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up_enc   = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_JAIS:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
+
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
+
+                        layer.ffn_gate   = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "bias", i),   {n_ff}, 0);
+
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_CHATGLM:
+                {
+                    tok_embd   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
+
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
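+                        // gate and up projections are fused into a single tensor, hence the 2*n_ff width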
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff * 2}, 0);
+
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                    }
+                } break;
+            case LLM_ARCH_NEMOTRON:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
+                    output        = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        // optional bias tensors
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+
+                        // optional MLP bias
+                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
+                    }
+                } break;
+            case LLM_ARCH_EXAONE:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM,   "weight", i), {n_embd}, 0);
+                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                        layer.ffn_gate   = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN,   "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,     "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_RWKV6:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // Block 0, LN0
+                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    const int time_mix_extra_dim = hparams.time_mix_extra_dim;
+                    const int time_decay_extra_dim = hparams.time_decay_extra_dim;
+                    const int head_size = hparams.wkv_head_size;
+                    const int attn_hidden_size = n_embd;
+                    const int ffn_size = hparams.n_ff_arr[0];
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
+
+                        layer.attn_norm_2   = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0);
+                        layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd}, 0);
+
+                        layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0);
+                        layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
+
+                        layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
+                        layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        GGML_ASSERT(!(layer.time_mix_lerp_fused == NULL && layer.time_mix_lerp_w == NULL));
+
+                        layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0);
+                        layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
+                        layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0);
+                        layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0);
+                        layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0);
+
+                        layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0);
+                        layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0);
+                        layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
+
+                        layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
+                        layer.channel_mix_lerp_r = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
+
+                        layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0);
+                        layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0);
+                        layer.channel_mix_receptance = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd}, 0);
+                    }
+
+                } break;
+            case LLM_ARCH_RWKV6QWEN2:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+                    const int time_mix_extra_dim = hparams.time_mix_extra_dim;
+                    const int time_decay_extra_dim = hparams.time_decay_extra_dim;
+                    const int head_size = hparams.wkv_head_size;
+                    const int attn_hidden_size = n_embd;
+                    const int n_head_kv = hparams.n_head_kv();
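+                    // when the K/V head count differs from the number of WKV heads (n_embd / head_size),
+                    // the key/value projections are n_head_kv * head_size wide instead of the full hidden size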
+                    int attn_key_value_size;
+                    if (n_head_kv == 0 || attn_hidden_size / head_size == n_head_kv) {
+                        attn_key_value_size = attn_hidden_size;
+                    } else {
+                        attn_key_value_size = n_head_kv * head_size;
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0);
+                        layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
+
+                        layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
+                        layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0);
+
+                        layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
+                        layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0);
+                        layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0);
+                        layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {n_embd, attn_key_value_size}, 0);
+                        layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {n_embd, attn_key_value_size}, 0);
+                        layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0);
+                        // optional bias tensors
+                        layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED);
+
+                        layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_CHAMELEON:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0);
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0);
+                        layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i),  {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED);
+                        layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i),  {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_WAVTOKENIZER_DEC:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd_features, n_vocab}, 0);
+
+                    conv1d   = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd_features, hparams.posnet.n_embd}, 0);
+                    conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"),   {1, hparams.posnet.n_embd}, 0);
+
+                    // posnet
+                    {
+                        const int64_t n_embd = hparams.posnet.n_embd;
+
+                        for (uint32_t i = 0; i < hparams.posnet.n_layer; ++i) {
+                            auto & layer = layers[i].posnet;
+
+                            // posnet:
+                            //
+                            //  - resnet
+                            //  - resnet
+                            //  - attn
+                            //  - resnet
+                            //  - resnet
+                            //  - norm
+                            //
+                            switch (i) {
+                                case 0:
+                                case 1:
+                                case 3:
+                                case 4:
+                                    {
+                                        layer.norm1   = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", i), {1, n_embd}, 0);
+                                        layer.norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias",   i), {1, n_embd}, 0);
+
+                                        layer.conv1   = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", i), {3, n_embd, n_embd}, 0);
+                                        layer.conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias",   i), {1, n_embd}, 0);
+
+                                        layer.norm2   = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", i), {1, n_embd}, 0);
+                                        layer.norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias",   i), {1, n_embd}, 0);
+
+                                        layer.conv2   = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", i), {3, n_embd, n_embd}, 0);
+                                        layer.conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias",   i), {1, n_embd}, 0);
+                                    } break;
+                                case 2:
+                                    {
+                                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0);
+                                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias",   i), {1, n_embd}, 0);
+
+                                        layer.attn_q      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q,    "weight", i), {1, n_embd, n_embd}, 0);
+                                        layer.attn_q_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q,    "bias",   i), {1, n_embd}, 0);
+
+                                        layer.attn_k      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K,    "weight", i), {1, n_embd, n_embd}, 0);
+                                        layer.attn_k_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K,    "bias",   i), {1, n_embd}, 0);
+
+                                        layer.attn_v      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V,    "weight", i), {1, n_embd, n_embd}, 0);
+                                        layer.attn_v_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V,    "bias",   i), {1, n_embd}, 0);
+
+                                        layer.attn_o      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT,  "weight", i), {1, n_embd, n_embd}, 0);
+                                        layer.attn_o_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT,  "bias",   i), {1, n_embd}, 0);
+                                    } break;
+                                case 5:
+                                    {
+                                        layer.norm   = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0);
+                                        layer.norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias",   i), {1, n_embd}, 0);
+                                    } break;
+                                default: GGML_ABORT("unknown posnet layer");
+                            };
+                        }
+                    }
+
+                    GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd);
+
+                    tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {hparams.posnet.n_embd}, 0);
+                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {hparams.posnet.n_embd}, 0);
+
+                    // convnext
+                    {
+                        const int64_t n_embd = hparams.convnext.n_embd;
+
+                        for (uint32_t i = 0; i < hparams.convnext.n_layer; ++i) {
+                            auto & layer = layers[i].convnext;
+
+                            layer.dw     = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW,    "weight", i), {7, 1, n_embd}, 0);
+                            layer.dw_b   = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW,    "bias",   i), {1, n_embd}, 0);
+
+                            layer.norm   = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM,  "weight", i), {n_embd}, 0);
+                            layer.norm_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM,  "bias",   i), {n_embd}, 0);
+
+                            layer.pw1    = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1,   "weight", i), {n_embd, n_ff}, 0);
+                            layer.pw1_b  = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1,   "bias",   i), {n_ff}, 0);
+
+                            layer.pw2    = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2,   "weight", i), {n_ff, n_embd}, 0);
+                            layer.pw2_b  = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2,   "bias",   i), {n_embd}, 0);
+
+                            layer.gamma  = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", i), {n_embd}, 0);
+                        }
+
+                        // output
+                        output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                        output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    }
+
+                    output   = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0);
+                    output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"),   {n_embd}, 0);
+                } break;
+            default:
+                throw std::runtime_error("unknown architecture");
+        }
+
+        if (n_moved_tensors > 0) {
+            LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %d others) cannot be used with preferred buffer type %s, using %s instead\n",
+                __func__, first_moved_tensor->name, ggml_type_name(first_moved_tensor->type), n_moved_tensors - 1,
+                ggml_backend_buft_name(first_moved_from_buft), ggml_backend_buft_name(first_moved_to_buft));
+        }
+    }
+
+    ml.done_getting_tensors();
+
+    ml.init_mappings(true, use_mlock ? &pimpl->mlock_mmaps : nullptr);
+    pimpl->mappings.reserve(ml.mappings.size());
+
+    // create the backend buffers
+    std::vector<std::pair<ggml_context *, llama_buf_map>> ctx_bufs;
+    ctx_bufs.reserve(ctx_map.size());
+
+    // Ensure we have enough capacity for the maximum backend buffer we will potentially create
+    const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size();
+    pimpl->bufs.reserve(n_max_backend_buffer);
+
+    for (auto & it : ctx_map) {
+        ggml_backend_buffer_type_t buft = it.first;
+        ggml_context * ctx              = it.second;
+
+        // skip contexts without tensors
+        if (ggml_get_first_tensor(ctx) == nullptr) {
+            continue;
+        }
+
+        llama_buf_map buf_map;
+        buf_map.reserve(n_max_backend_buffer);
+
+        // check if it is possible to use buffer_from_host_ptr with this buffer type
+        ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft);
+        if (!dev) {
+            // FIXME: workaround for CPU backend buft having a NULL device
+            dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+        }
+        ggml_backend_dev_props props;
+        ggml_backend_dev_get_props(dev, &props);
+        bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr;
+        bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev);
+
+        if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) {
+            for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
+                // only the mmap region containing the tensors in the model is mapped to the backend buffer
+                // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers
+                // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size
+                void * addr = nullptr;
+                size_t first, last; // NOLINT
+                ml.get_mapping_range(&first, &last, &addr, idx, ctx);
+                if (first >= last) {
+                    continue;
+                }
+                const size_t max_size = ggml_get_max_tensor_size(ctx);
+                ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
+                if (buf == nullptr) {
+                    throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
+                }
+                pimpl->bufs.emplace_back(buf);
+                buf_map.emplace(idx, buf);
+            }
+        }
+        else {
+            ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
+            if (buf == nullptr) {
+                throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
+            }
+            pimpl->bufs.emplace_back(buf);
+            if (use_mlock && ggml_backend_buffer_is_host(buf)) {
+                pimpl->mlock_bufs.emplace_back(new llama_mlock);
+                auto & mlock_buf = pimpl->mlock_bufs.back();
+                mlock_buf->init   (ggml_backend_buffer_get_base(buf));
+                mlock_buf->grow_to(ggml_backend_buffer_get_size(buf));
+            }
+            for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
+                buf_map.emplace(idx, buf);
+            }
+        }
+
+        if (pimpl->bufs.empty()) {
+            throw std::runtime_error("failed to allocate buffer");
+        }
+
+        for (auto & buf : buf_map) {
+            // indicate that this buffer contains weights
+            // this is used by ggml_backend_sched to improve op scheduling: ops that use a weight are preferably scheduled to the backend that contains the weight
+            ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+        }
+
+        ctx_bufs.emplace_back(ctx, buf_map);
+    }
+
+    if (llama_supports_gpu_offload()) {
+        const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
+
+        LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu);
+        if (n_gpu_layers > (int) hparams.n_layer) {
+            LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__);
+        }
+
+        const int max_backend_supported_layers = hparams.n_layer + 1;
+        const int max_offloadable_layers       = hparams.n_layer + 1;
+
+        LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
+    }
+
+    // print memory requirements per buffer type
+    for (auto & buf : pimpl->bufs) {
+        LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
+    }
+
+    // populate tensors_by_name
+    for (auto & ctx : pimpl->ctxs) {
+        for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) {
+            tensors_by_name.emplace_back(ggml_get_name(cur), cur);
+        }
+    }
+
+    // load tensor data
+    for (auto & it : ctx_bufs) {
+        ggml_context * ctx = it.first;
+        auto & bufs = it.second;
+        if (!ml.load_all_data(ctx, bufs, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) {
+            return false;
+        }
+    }
+
+    if (use_mmap_buffer) {
+        for (auto & mapping : ml.mappings) {
+            pimpl->mappings.emplace_back(std::move(mapping));
+        }
+    }
+
+    return true;
+}
+
+std::string llama_model::arch_name() const {
+    return llm_arch_name(arch);
+}
+
+std::string llama_model::type_name() const {
+    return llm_type_name(type);
+}
+
+std::string llama_model::desc() const {
+    return pimpl->desc_str;
+}
+
+size_t llama_model::size() const {
+    return pimpl->n_bytes;
+}
+
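+// upper bound on the number of graph nodes, scaled with the number of model tensors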
+size_t llama_model::max_nodes() const {
+    return std::max<size_t>(8192, tensors_by_name.size()*5);
+}
+
+size_t llama_model::n_devices() const {
+    return devices.size();
+}
+
+uint64_t llama_model::n_elements() const {
+    return pimpl->n_elements;
+}
+
+void llama_model::print_info() const {
+    const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
+
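+    // helper: formats per-layer values, printing the full list only when the values vary across layers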
+    auto print_f = [](const std::function<uint32_t(uint32_t)> & f, uint32_t n) {
+        bool is_var = false;
+
+        std::vector<uint32_t> v;
+        for (uint32_t i = 0; i < n; ++i) {
+            v.push_back(f(i));
+            if (v[i] != v[0]) {
+                is_var = true;
+            }
+        }
+
+        std::stringstream ss;
+
+        if (is_var) {
+            ss << "[";
+            for (uint32_t i = 0; i < n; ++i) {
+                ss << v[i];
+                if (i < n - 1) {
+                    ss << ", ";
                 }
             }
             ss << "]";
@@ -1859,11 +3550,7 @@ void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     };
 
     // hparams
-    LLAMA_LOG_INFO("%s: format           = %s\n",     __func__, llama_file_version_name(ml.fver));
-    LLAMA_LOG_INFO("%s: arch             = %s\n",     __func__, llm_arch_name(model.arch));
-    LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, llama_model_vocab_type_name(vocab.type));
-    LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, hparams.n_vocab);
-    LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (int) vocab.bpe_ranks.size());
+    LLAMA_LOG_INFO("%s: arch             = %s\n",     __func__, arch_name().c_str());
     LLAMA_LOG_INFO("%s: vocab_only       = %d\n",     __func__, hparams.vocab_only);
 
     if (!hparams.vocab_only) {
@@ -1902,60 +3589,28 @@ void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms   = %d\n",     __func__, hparams.ssm_dt_b_c_rms);
     }
 
-    LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, llama_model_type_name(model).c_str());
-    LLAMA_LOG_INFO("%s: model ftype      = %s\n",     __func__, llama_model_ftype_name(model).c_str());
-    if (ml.n_elements >= 1e12) {
-        LLAMA_LOG_INFO("%s: model params     = %.2f T\n", __func__, ml.n_elements*1e-12);
-    } else if (ml.n_elements >= 1e9) {
-        LLAMA_LOG_INFO("%s: model params     = %.2f B\n", __func__, ml.n_elements*1e-9);
-    } else if (ml.n_elements >= 1e6) {
-        LLAMA_LOG_INFO("%s: model params     = %.2f M\n", __func__, ml.n_elements*1e-6);
+    LLAMA_LOG_INFO("%s: model type       = %s\n",     __func__, type_name().c_str());
+    if (pimpl->n_elements >= 1e12) {
+        LLAMA_LOG_INFO("%s: model params     = %.2f T\n", __func__, pimpl->n_elements*1e-12);
+    } else if (pimpl->n_elements >= 1e9) {
+        LLAMA_LOG_INFO("%s: model params     = %.2f B\n", __func__, pimpl->n_elements*1e-9);
+    } else if (pimpl->n_elements >= 1e6) {
+        LLAMA_LOG_INFO("%s: model params     = %.2f M\n", __func__, pimpl->n_elements*1e-6);
     } else {
-        LLAMA_LOG_INFO("%s: model params     = %.2f K\n", __func__, ml.n_elements*1e-3);
-    }
-    if (ml.n_bytes < GiB) {
-        LLAMA_LOG_INFO("%s: model size       = %.2f MiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0,        ml.n_bytes*8.0/ml.n_elements);
-    } else {
-        LLAMA_LOG_INFO("%s: model size       = %.2f GiB (%.2f BPW) \n", __func__, ml.n_bytes/1024.0/1024.0/1024.0, ml.n_bytes*8.0/ml.n_elements);
+        LLAMA_LOG_INFO("%s: model params     = %.2f K\n", __func__, pimpl->n_elements*1e-3);
     }
 
     // general kv
-    LLAMA_LOG_INFO("%s: general.name     = %s\n",    __func__, model.name.c_str());
-
-    // special tokens
-    if (vocab.special_bos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, vocab.special_bos_id,     vocab.id_to_token[vocab.special_bos_id].text.c_str() );  }
-    if (vocab.special_eos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, vocab.special_eos_id,     vocab.id_to_token[vocab.special_eos_id].text.c_str() );  }
-    if (vocab.special_eot_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, vocab.special_eot_id,     vocab.id_to_token[vocab.special_eot_id].text.c_str() );  }
-    if (vocab.special_eom_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, vocab.special_eom_id,     vocab.id_to_token[vocab.special_eom_id].text.c_str() );  }
-    if (vocab.special_unk_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, vocab.special_unk_id,     vocab.id_to_token[vocab.special_unk_id].text.c_str() );  }
-    if (vocab.special_sep_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, vocab.special_sep_id,     vocab.id_to_token[vocab.special_sep_id].text.c_str() );  }
-    if (vocab.special_pad_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, vocab.special_pad_id,     vocab.id_to_token[vocab.special_pad_id].text.c_str() );  }
-    if (vocab.special_cls_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: CLS token        = %d '%s'\n", __func__, vocab.special_cls_id,     vocab.id_to_token[vocab.special_cls_id].text.c_str() );  }
-    if (vocab.special_mask_id != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, vocab.special_mask_id,    vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
-
-    if (vocab.linefeed_id != LLAMA_TOKEN_NULL)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, vocab.linefeed_id,        vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
-
-    if (vocab.special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, vocab.special_fim_pre_id, vocab.id_to_token[vocab.special_fim_pre_id].text.c_str() ); }
-    if (vocab.special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, vocab.special_fim_suf_id, vocab.id_to_token[vocab.special_fim_suf_id].text.c_str() ); }
-    if (vocab.special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, vocab.special_fim_mid_id, vocab.id_to_token[vocab.special_fim_mid_id].text.c_str() ); }
-    if (vocab.special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, vocab.special_fim_pad_id, vocab.id_to_token[vocab.special_fim_pad_id].text.c_str() ); }
-    if (vocab.special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, vocab.special_fim_rep_id, vocab.id_to_token[vocab.special_fim_rep_id].text.c_str() ); }
-    if (vocab.special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, vocab.special_fim_sep_id, vocab.id_to_token[vocab.special_fim_sep_id].text.c_str() ); }
-
-    for (const auto & id : vocab.special_eog_ids) {
-        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, vocab.id_to_token[id].text.c_str() );
-    }
-
-    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
+    LLAMA_LOG_INFO("%s: general.name     = %s\n",    __func__, name.c_str());
 
-    if (model.arch == LLM_ARCH_DEEPSEEK) {
+    if (arch == LLM_ARCH_DEEPSEEK) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",     __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n",     __func__, hparams.n_ff_exp);
         LLAMA_LOG_INFO("%s: n_expert_shared      = %d\n",     __func__, hparams.n_expert_shared);
         LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n",   __func__, hparams.expert_weights_scale);
     }
 
-    if (model.arch == LLM_ARCH_DEEPSEEK2) {
+    if (arch == LLM_ARCH_DEEPSEEK2) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead   = %d\n",     __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_lora_q             = %d\n",     __func__, hparams.n_lora_q);
         LLAMA_LOG_INFO("%s: n_lora_kv            = %d\n",     __func__, hparams.n_lora_kv);
@@ -1967,16 +3622,88 @@ void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: rope_yarn_log_mul    = %.4f\n",   __func__, hparams.rope_yarn_log_mul);
     }
 
-    if (model.arch == LLM_ARCH_QWEN2MOE) {
+    if (arch == LLM_ARCH_QWEN2MOE) {
         LLAMA_LOG_INFO("%s: n_ff_exp         = %d\n",     __func__, hparams.n_ff_exp);
         LLAMA_LOG_INFO("%s: n_ff_shexp       = %d\n",     __func__, hparams.n_ff_shexp);
     }
 
-    if (model.arch == LLM_ARCH_MINICPM || model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) {
+    if (arch == LLM_ARCH_MINICPM || arch == LLM_ARCH_GRANITE || arch == LLM_ARCH_GRANITE_MOE) {
         LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
         LLAMA_LOG_INFO("%s: f_residual_scale  = %f\n", __func__, hparams.f_residual_scale);
         LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
     }
+
+    vocab.print_info();
+}
+
+ggml_backend_dev_t llama_model::dev_layer(int il) const {
+    return pimpl->dev_layer.at(il).dev;
+}
+
+ggml_backend_dev_t llama_model::dev_output() const {
+    return pimpl->dev_output.dev;
+}
+
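+// check whether the device can run the op built by `fn` when its inputs are placed in the given buffer type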
+template<typename F>
+static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) {
+    ggml_init_params params = {
+        /*.mem_size   =*/ ggml_tensor_overhead()*8,
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+
+    ggml_context_ptr ctx { ggml_init(params) };
+    if (!ctx) {
+        throw std::runtime_error(format("failed to create ggml context"));
+    }
+
+    ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) };
+    ggml_tensor * op_tensor = fn(ctx.get());
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (op_tensor->src[i] != nullptr) {
+            assert(op_tensor->src[i]->buffer == nullptr);
+            op_tensor->src[i]->buffer = buf.get();
+        }
+    }
+
+    bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
+
+    return op_supported;
+}
+
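+// return the first buffer type in the list whose device supports the op built by `fn`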
+template<typename F>
+static ggml_backend_buffer_type_t select_buft(const buft_list_t & buft_list, const F & fn) {
+    for (const auto & cur : buft_list) {
+        ggml_backend_dev_t cur_dev = cur.first;
+        ggml_backend_buffer_type_t cur_buft = cur.second;
+        if (buft_supported(cur_buft, cur_dev, fn)) {
+            return cur_buft;
+        }
+    }
+
+    throw std::runtime_error(format("no suitable buffer type found"));
+}
+
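+// pick a buffer type for layer il by probing with a representative f32 add over n_embd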
+ggml_backend_buffer_type_t llama_model::select_buft(int il) const {
+    return ::select_buft(
+            *pimpl->dev_layer.at(il).buft_list,
+            [&](ggml_context * ctx) {
+                ggml_tensor * cur = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd);
+                ggml_tensor * layer_dir = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd);
+                return ggml_add(ctx, cur, layer_dir);
+            });
+}
+
+const struct ggml_tensor * llama_model::get_tensor(const char * name) const {
+    auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(),
+            [name](const std::pair<std::string, struct ggml_tensor *> & it) {
+                return it.first == name;
+            });
+    if (it == tensors_by_name.end()) {
+        return nullptr;
+    }
+
+    return it->second;
 }
 
 //
@@ -2008,6 +3735,10 @@ struct llama_model_params llama_model_default_params() {
     return result;
 }
 
+const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model) {
+    return &model->vocab;
+}
+
 void llama_free_model(struct llama_model * model) {
     llama_model_free(model);
 }
@@ -2016,31 +3747,43 @@ void llama_model_free(struct llama_model * model) {
     delete model;
 }
 
-enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
-    return model->vocab.type;
+int32_t llama_model_n_ctx_train(const struct llama_model * model) {
+    return model->hparams.n_ctx_train;
+}
+
+int32_t llama_model_n_embd(const struct llama_model * model) {
+    return model->hparams.n_embd;
+}
+
+int32_t llama_model_n_layer(const struct llama_model * model) {
+    return model->hparams.n_layer;
 }
 
-int32_t llama_n_vocab(const struct llama_model * model) {
-    return model->hparams.n_vocab;
+int32_t llama_model_n_head(const struct llama_model * model) {
+    return model->hparams.n_head();
 }
 
+// deprecated
 int32_t llama_n_ctx_train(const struct llama_model * model) {
-    return model->hparams.n_ctx_train;
+    return llama_model_n_ctx_train(model);
 }
 
+// deprecated
 int32_t llama_n_embd(const struct llama_model * model) {
-    return model->hparams.n_embd;
+    return llama_model_n_embd(model);
 }
 
+// deprecated
 int32_t llama_n_layer(const struct llama_model * model) {
-    return model->hparams.n_layer;
+    return llama_model_n_layer(model);
 }
 
+// deprecated
 int32_t llama_n_head(const struct llama_model * model) {
-    return model->hparams.n_head();
+    return llama_model_n_head(model);
 }
 
-enum llama_rope_type llama_rope_type(const struct llama_model * model) {
+enum llama_rope_type llama_model_rope_type(const struct llama_model * model) {
     switch (model->arch) {
         // these models do not use RoPE
         case LLM_ARCH_GPT2:
@@ -2054,6 +3797,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_T5ENCODER:
         case LLM_ARCH_JAIS:
         case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
         case LLM_ARCH_WAVTOKENIZER_DEC:
             return LLAMA_ROPE_TYPE_NONE;
 
@@ -2094,6 +3838,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_OLMOE:
         case LLM_ARCH_PHI2:
         case LLM_ARCH_PHI3:
+        case LLM_ARCH_PHIMOE:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_GEMMA2:
         case LLM_ARCH_STARCODER2:
@@ -2116,7 +3861,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
     return LLAMA_ROPE_TYPE_NONE;
 }
 
-float llama_rope_freq_scale_train(const struct llama_model * model) {
+float llama_model_rope_freq_scale_train(const struct llama_model * model) {
     return model->hparams.rope_freq_scale_train;
 }
 
@@ -2160,18 +3905,24 @@ int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int3
 }
 
 int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
-    return snprintf(buf, buf_size, "%s %s %s",
-            llama_model_arch_name (*model).c_str(),
-            llama_model_type_name (*model).c_str(),
-            llama_model_ftype_name(*model).c_str());
+    return snprintf(buf, buf_size, "%s", model->desc().c_str());
 }
 
 uint64_t llama_model_size(const struct llama_model * model) {
-    return model->n_bytes;
+    return model->size();
+}
+
+const char * llama_model_chat_template(const struct llama_model * model) {
+    const auto & it = model->gguf_kv.find(LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE));
+    if (it == model->gguf_kv.end()) {
+        return nullptr;
+    }
+
+    return it->second.c_str();
 }
 
 uint64_t llama_model_n_params(const struct llama_model * model) {
-    return model->n_elements;
+    return model->n_elements();
 }
 
 bool llama_model_has_encoder(const struct llama_model * model) {
@@ -2197,6 +3948,7 @@ bool llama_model_is_recurrent(const struct llama_model * model) {
     switch (model->arch) {
         case LLM_ARCH_MAMBA:  return true;
         case LLM_ARCH_RWKV6:  return true;
+        case LLM_ARCH_RWKV6QWEN2: return true;
         default:              return false;
     }
 }
diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h
index ce038932d4e..4cc8abb753a 100644
--- a/examples/talk-llama/llama-model.h
+++ b/examples/talk-llama/llama-model.h
@@ -4,78 +4,80 @@
 #include "llama-arch.h"
 #include "llama-hparams.h"
 #include "llama-vocab.h"
-#include "llama-mmap.h"
-
-#include "ggml-cpp.h"
 
+#include <memory>
+#include <string>
+#include <unordered_map>
 #include <vector>
 
+struct llama_model_loader;
+
 // available models
-// TODO: this enum does not follow the enum naming convention
 enum llm_type {
-    MODEL_UNKNOWN,
-    MODEL_14M,
-    MODEL_17M,
-    MODEL_22M,
-    MODEL_33M,
-    MODEL_60M,
-    MODEL_70M,
-    MODEL_80M,
-    MODEL_109M,
-    MODEL_137M,
-    MODEL_160M,
-    MODEL_220M,
-    MODEL_250M,
-    MODEL_270M,
-    MODEL_335M,
-    MODEL_410M,
-    MODEL_450M,
-    MODEL_770M,
-    MODEL_780M,
-    MODEL_0_5B,
-    MODEL_1B,
-    MODEL_1_3B,
-    MODEL_1_4B,
-    MODEL_1_5B,
-    MODEL_1_6B,
-    MODEL_2B,
-    MODEL_2_8B,
-    MODEL_3B,
-    MODEL_4B,
-    MODEL_6B,
-    MODEL_6_9B,
-    MODEL_7B,
-    MODEL_8B,
-    MODEL_9B,
-    MODEL_11B,
-    MODEL_12B,
-    MODEL_13B,
-    MODEL_14B,
-    MODEL_15B,
-    MODEL_16B,
-    MODEL_20B,
-    MODEL_30B,
-    MODEL_32B,
-    MODEL_34B,
-    MODEL_35B,
-    MODEL_40B,
-    MODEL_65B,
-    MODEL_70B,
-    MODEL_236B,
-    MODEL_314B,
-    MODEL_671B,
-    MODEL_SMALL,
-    MODEL_MEDIUM,
-    MODEL_LARGE,
-    MODEL_XL,
-    MODEL_A1_7B,
-    MODEL_A2_7B,
-    MODEL_8x7B,
-    MODEL_8x22B,
-    MODEL_16x12B,
-    MODEL_10B_128x3_66B,
-    MODEL_57B_A14B,
-    MODEL_27B,
+    LLM_TYPE_UNKNOWN,
+    LLM_TYPE_14M,
+    LLM_TYPE_17M,
+    LLM_TYPE_22M,
+    LLM_TYPE_33M,
+    LLM_TYPE_60M,
+    LLM_TYPE_70M,
+    LLM_TYPE_80M,
+    LLM_TYPE_109M,
+    LLM_TYPE_137M,
+    LLM_TYPE_160M,
+    LLM_TYPE_220M,
+    LLM_TYPE_250M,
+    LLM_TYPE_270M,
+    LLM_TYPE_335M,
+    LLM_TYPE_410M,
+    LLM_TYPE_450M,
+    LLM_TYPE_770M,
+    LLM_TYPE_780M,
+    LLM_TYPE_0_5B,
+    LLM_TYPE_1B,
+    LLM_TYPE_1_3B,
+    LLM_TYPE_1_4B,
+    LLM_TYPE_1_5B,
+    LLM_TYPE_1_6B,
+    LLM_TYPE_2B,
+    LLM_TYPE_2_8B,
+    LLM_TYPE_3B,
+    LLM_TYPE_4B,
+    LLM_TYPE_6B,
+    LLM_TYPE_6_9B,
+    LLM_TYPE_7B,
+    LLM_TYPE_8B,
+    LLM_TYPE_9B,
+    LLM_TYPE_11B,
+    LLM_TYPE_12B,
+    LLM_TYPE_13B,
+    LLM_TYPE_14B,
+    LLM_TYPE_15B,
+    LLM_TYPE_16B,
+    LLM_TYPE_20B,
+    LLM_TYPE_30B,
+    LLM_TYPE_32B,
+    LLM_TYPE_34B,
+    LLM_TYPE_35B,
+    LLM_TYPE_40B,
+    LLM_TYPE_65B,
+    LLM_TYPE_70B,
+    LLM_TYPE_236B,
+    LLM_TYPE_314B,
+    LLM_TYPE_671B,
+    LLM_TYPE_SMALL,
+    LLM_TYPE_MEDIUM,
+    LLM_TYPE_LARGE,
+    LLM_TYPE_XL,
+    LLM_TYPE_A1_7B,
+    LLM_TYPE_A2_7B,
+    LLM_TYPE_8x7B,
+    LLM_TYPE_8x22B,
+    LLM_TYPE_16x12B,
+    LLM_TYPE_16x3_8B,
+    LLM_TYPE_10B_128x3_66B,
+    LLM_TYPE_57B_A14B,
+    LLM_TYPE_27B,
 };
 
 struct llama_layer_posnet {
@@ -240,15 +242,19 @@ struct llama_layer {
     struct ggml_tensor * time_mix_lerp_v     = nullptr;
     struct ggml_tensor * time_mix_lerp_r     = nullptr;
     struct ggml_tensor * time_mix_lerp_g     = nullptr;
-
-    struct ggml_tensor * time_mix_first      = nullptr;
-    struct ggml_tensor * time_mix_decay      = nullptr;
-    struct ggml_tensor * time_mix_decay_w1   = nullptr;
-    struct ggml_tensor * time_mix_decay_w2   = nullptr;
-    struct ggml_tensor * time_mix_key        = nullptr;
-    struct ggml_tensor * time_mix_value      = nullptr;
-    struct ggml_tensor * time_mix_receptance = nullptr;
-    struct ggml_tensor * time_mix_gate       = nullptr;
+    struct ggml_tensor * time_mix_lerp_fused = nullptr;
+
+    struct ggml_tensor * time_mix_first        = nullptr;
+    struct ggml_tensor * time_mix_decay        = nullptr;
+    struct ggml_tensor * time_mix_decay_w1     = nullptr;
+    struct ggml_tensor * time_mix_decay_w2     = nullptr;
+    struct ggml_tensor * time_mix_key          = nullptr;
+    struct ggml_tensor * time_mix_key_b        = nullptr;
+    struct ggml_tensor * time_mix_value        = nullptr;
+    struct ggml_tensor * time_mix_value_b      = nullptr;
+    struct ggml_tensor * time_mix_receptance   = nullptr;
+    struct ggml_tensor * time_mix_receptance_b = nullptr;
+    struct ggml_tensor * time_mix_gate         = nullptr;
 
     struct ggml_tensor * time_mix_ln     = nullptr;
     struct ggml_tensor * time_mix_ln_b   = nullptr;
@@ -281,11 +287,9 @@ struct llama_layer {
 };
 
 struct llama_model {
-    llm_type type = MODEL_UNKNOWN;
+    llm_type type = LLM_TYPE_UNKNOWN;
     llm_arch arch = LLM_ARCH_UNKNOWN;
 
-    llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
-
     std::string name = "n/a";
 
     llama_hparams hparams = {};
@@ -314,78 +318,55 @@ struct llama_model {
 
     std::vector<llama_layer> layers;
 
+    llama_model_params params;
+
     // gguf metadata
     std::unordered_map<std::string, std::string> gguf_kv;
 
-    llama_split_mode split_mode;
-    int main_gpu;
-    int n_gpu_layers;
-
     std::vector<std::string> rpc_servers;
 
     // list of devices used in this model
     std::vector<ggml_backend_dev_t> devices;
 
-
-    // lists of buffer types used for each layer
-    using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>;
-    buft_list_t cpu_buft_list;
-    std::map<ggml_backend_dev_t, buft_list_t> gpu_buft_list;
-
-    struct layer_dev {
-        ggml_backend_dev_t dev;
-        buft_list_t * buft_list;
-    };
-
-    layer_dev dev_input = {};
-    layer_dev dev_output = {};
-    std::vector<layer_dev> dev_layer;
-
-    // contexts where the model tensors metadata is stored
-    std::vector<ggml_context_ptr> ctxs;
-
-    // the model memory buffers for the tensor data
-    std::vector<ggml_backend_buffer_ptr> bufs;
-
-    // model memory mapped files
-    llama_mmaps mappings;
-
-    // objects representing data potentially being locked in memory
-    llama_mlocks mlock_bufs;
-    llama_mlocks mlock_mmaps;
-
     // for quantize-stats only
     std::vector<std::pair<std::string, struct ggml_tensor *>> tensors_by_name;
 
     int64_t t_load_us  = 0;
     int64_t t_start_us = 0;
 
-    // total number of parameters in the model
-    uint64_t n_elements = 0;
+    explicit llama_model(const struct llama_model_params & params);
+    ~llama_model();
 
-    // total size of all the tensors in the model in bytes
-    size_t  n_bytes     = 0;
-};
+    void load_stats  (llama_model_loader & ml);
+    void load_arch   (llama_model_loader & ml);
+    void load_hparams(llama_model_loader & ml);
+    void load_vocab  (llama_model_loader & ml);
+    bool load_tensors(llama_model_loader & ml); // returns false if cancelled by progress_callback
 
-const char * llm_type_name(llm_type type);
+    std::string arch_name() const;
+    std::string type_name() const;
+
+    std::string desc() const;
 
-std::string llama_model_arch_name (const llama_model & model);
-std::string llama_model_type_name (const llama_model & model);
-std::string llama_model_ftype_name(const llama_model & model);
+    size_t size() const;
+    size_t max_nodes() const;
+    size_t n_devices() const;
 
-// used by llama_adapter_cvec
-ggml_backend_buffer_type_t llama_model_select_buft(const llama_model & model, int il);
+    // total number of parameters in the model
+    uint64_t n_elements() const;
 
-// used by llama_adapter_lora
-struct ggml_tensor * llama_model_get_tensor(const struct llama_model & model, const char * name);
+    void print_info() const;
 
-size_t llama_model_max_nodes(const llama_model & model);
+    ggml_backend_dev_t dev_layer(int il) const;
+    ggml_backend_dev_t dev_output() const;
 
-struct llama_model_loader;
+    ggml_backend_buffer_type_t select_buft(int il) const;
+
+    const struct ggml_tensor * get_tensor(const char * name) const;
 
-// TODO: become llama_model methods
-void llm_load_stats     (llama_model_loader & ml, llama_model & model);
-void llm_load_arch      (llama_model_loader & ml, llama_model & model);
-void llm_load_hparams   (llama_model_loader & ml, llama_model & model);
-void llm_load_vocab     (llama_model_loader & ml, llama_model & model);
-void llm_load_print_meta(llama_model_loader & ml, llama_model & model);
+private:
+    struct impl;
+    std::unique_ptr<impl> pimpl;
+};
+
+const char * llm_type_name(llm_type type);
diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp
index 104f90343a4..d4947a780c1 100644
--- a/examples/talk-llama/llama-quant.cpp
+++ b/examples/talk-llama/llama-quant.cpp
@@ -7,14 +7,12 @@
 #include <algorithm>
 #include <cmath>
 #include <cstring>
+#include <cinttypes>
 #include <fstream>
 #include <mutex>
 #include <thread>
 #include <unordered_map>
 
-// TODO: replace with ggml API call
-#define QK_K 256
-
 static void zeros(std::ofstream & file, size_t n) {
     char zero = 0;
     for (size_t i = 0; i < n; ++i) {
@@ -154,8 +152,10 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
         if (qs.params->output_tensor_type < GGML_TYPE_COUNT) {
             new_type = qs.params->output_tensor_type;
         } else {
-            int nx = tensor->ne[0];
-            if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) {
+            const int64_t nx = tensor->ne[0];
+            const int64_t qk_k = ggml_blck_size(new_type);
+
+            if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) {
                 new_type = GGML_TYPE_Q8_0;
             }
             else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
@@ -235,7 +235,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
         else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
                 use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K;
         else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
-        if (qs.model.type == MODEL_70B) {
+        if (qs.model.type == LLM_TYPE_70B) {
             // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
             // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
             // nearly negligible increase in model size by quantizing this tensor with more bits:
@@ -367,20 +367,19 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
     //    if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K;
     //}
     bool convert_incompatible_tensor = false;
-    if (new_type == GGML_TYPE_Q2_K    || new_type == GGML_TYPE_Q3_K    || new_type == GGML_TYPE_Q4_K   ||
-        new_type == GGML_TYPE_Q5_K    || new_type == GGML_TYPE_Q6_K    || new_type == GGML_TYPE_IQ4_XS ||
-        new_type == GGML_TYPE_IQ2_XS  || new_type == GGML_TYPE_IQ2_XXS || new_type == GGML_TYPE_IQ2_S  ||
-        new_type == GGML_TYPE_IQ3_XXS || new_type == GGML_TYPE_IQ1_S   || new_type == GGML_TYPE_IQ3_S  ||
-        new_type == GGML_TYPE_IQ1_M) {
-        int nx = tensor->ne[0];
-        int ny = tensor->ne[1];
-        if (nx % QK_K != 0) {
-            LLAMA_LOG_WARN("\n\n%s : tensor cols %d x %d are not divisible by %d, required for %s", __func__, nx, ny, QK_K, ggml_type_name(new_type));
+    {
+        const int64_t nx = tensor->ne[0];
+        const int64_t ny = tensor->ne[1];
+        const int64_t qk_k = ggml_blck_size(new_type);
+
+        if (nx % qk_k != 0) {
+            LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type));
             convert_incompatible_tensor = true;
         } else {
             ++qs.n_k_quantized;
         }
     }
+
     if (convert_incompatible_tensor) {
         switch (new_type) {
             case GGML_TYPE_TQ1_0:
@@ -526,18 +525,20 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
         kv_overrides = v->data();
     }
+
     llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides);
     ml.init_mappings(false); // no prefetching
 
-    llama_model model;
-    llm_load_arch   (ml, model);
-    llm_load_hparams(ml, model);
-    llm_load_stats  (ml, model);
+    llama_model model(llama_model_default_params());
+
+    model.load_arch   (ml);
+    model.load_hparams(ml);
+    model.load_stats  (ml);
 
     struct quantize_state_impl qs(model, params);
 
     if (params->only_copy) {
-        ftype = model.ftype;
+        ftype = ml.ftype;
     }
     const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr;
     if (params->imatrix) {
@@ -621,7 +622,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
 
     qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer;
 
-    // sanity checks
+    // sanity checks for models that have attention layers
+    if (qs.n_attention_wv != 0)
     {
         const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
         // attention layers have a non-zero number of kv heads
@@ -759,6 +761,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         quantize &= name.find("time_mix_w2.weight") == std::string::npos;
         quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
         quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
+        quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
 
         // do not quantize relative position bias (T5)
         quantize &= name.find("attn_rel_b.weight") == std::string::npos;
@@ -875,7 +878,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
 
         // update the gguf meta data as we go
         gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type);
-        gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data, new_size);
+        GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size);
+        gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data);
 
         // write tensor data + padding
         fout.write((const char *) new_data, new_size);
diff --git a/examples/talk-llama/llama-sampling.cpp b/examples/talk-llama/llama-sampling.cpp
index ef5a576ccf7..b3a12386e8a 100644
--- a/examples/talk-llama/llama-sampling.cpp
+++ b/examples/talk-llama/llama-sampling.cpp
@@ -371,7 +371,10 @@ void llama_sampler_free(struct llama_sampler * smpl) {
 llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
     const auto * logits = llama_get_logits_ith(ctx, idx);
 
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const int n_vocab = llama_vocab_n_tokens(vocab);
 
     // TODO: do not allocate each time
     std::vector<llama_token_data> cur;
@@ -1445,7 +1448,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
 static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
 
-    auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
+    auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr);
 
     // copy the state
     {
@@ -1481,19 +1484,19 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
     /* .free   = */ llama_sampler_grammar_free,
 };
 
-struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
+struct llama_sampler * llama_sampler_init_grammar(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
     auto * ctx = new llama_sampler_grammar;
 
     if (grammar_str != nullptr && grammar_str[0] != '\0') {
         *ctx = {
-            /* .vocab        = */ &vocab,
+            /* .vocab        = */ vocab,
             /* .grammar_str  = */ grammar_str,
             /* .grammar_root = */ grammar_root,
-            /* .grammar      = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
+            /* .grammar      = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root),
         };
     } else {
         *ctx = {
-            /* .vocab        = */ &vocab,
+            /* .vocab        = */ vocab,
             /* .grammar_str  = */ {},
             /* .grammar_root = */ {},
             /* .grammar      = */ nullptr,
@@ -1663,8 +1666,8 @@ struct llama_sampler_dry {
 
 // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
 static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
-    for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
-        std::string word = llama_detokenize(vocab, {token_id}, true);
+    for (llama_token token_id = 0; token_id < (llama_token) vocab.n_tokens(); token_id++) {
+        std::string word = vocab.detokenize({token_id}, true);
         if (word.find(str) != std::string::npos) {
             token_sequences.emplace(token_id, std::vector<llama_token>());
         } else {
@@ -1681,7 +1684,7 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std
                     }
                 }
                 if (match) {
-                    std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
+                    std::vector<llama_token> tokenization = vocab.tokenize(str.substr(i), false, false);
                     if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
                         tokenization.resize(max_tail_len);
                     }
@@ -1937,7 +1940,7 @@ static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler
     llama_vocab dummy_vocab;
 
     // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
-    auto * result = llama_sampler_init_dry_impl(dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
+    auto * result = llama_sampler_init_dry(&dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
 
     // Copy the state, including the processed breakers
     {
@@ -1964,7 +1967,7 @@ static struct llama_sampler_i llama_sampler_dry_i = {
     /* .free   = */ llama_sampler_dry_free,
 };
 
-struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
+struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
     int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
     std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
     const int MAX_CHAR_LEN = 40;
@@ -1991,7 +1994,7 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
                 sequence_break.resize(MAX_CHAR_LEN);
             }
 
-            get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
+            get_overlapping_token_sequences(*vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
         }
     }
 
@@ -2014,7 +2017,7 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
 // wrapper for test-sampling.cpp
 struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
     llama_vocab dummy_vocab;
-    auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
+    auto * result = llama_sampler_init_dry(&dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
     auto * ctx = (llama_sampler_dry *) result->ctx;
 
     // Process the token-based sequence breakers
@@ -2153,7 +2156,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     float p_eog_sum = 0.0f;
 
     for (size_t i = 0; i < cur_p->size; ++i) {
-        if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+        if (ctx->vocab->is_eog(cur_p->data[i].id)) {
             p_eog_sum += cur_p->data[i].p;
         } else {
             p_txt_sum += cur_p->data[i].p;
@@ -2175,7 +2178,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
         float p_sum = 0.0f;
 
         for (size_t i = 0; i < size_org; ++i) {
-            if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+            if (ctx->vocab->is_eog(cur_p->data[i].id)) {
                 p_sum += cur_p->data[i].p;
 
                 cur_p->data[cur_p->size++] = cur_p->data[i];
@@ -2203,17 +2206,17 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
                 continue;
             }
 
-            int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
+            int len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
             if (len0 < 0) {
                 ctx->buf0.resize(len0);
-                len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
+                len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
                 assert(len0 > 0);
             }
 
-            int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
+            int len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
             if (len1 < 0) {
                 ctx->buf1.resize(len1);
-                len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
+                len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
                 assert(len1 > 0);
             }
 
@@ -2248,7 +2251,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
 
     for (size_t i = 0; i < size_org; ++i) {
-        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+        const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
 
         if (cur_p->data[i].p < thold && !is_eog) {
             continue;
@@ -2269,7 +2272,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     // if no non-EOG tokens are left -> reduce cur_p to single EOT token
     if (n_non_eog == 0) {
         cur_p->size = 1;
-        cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
+        cur_p->data[0].id = ctx->vocab->token_eot();
         cur_p->data[0].logit = 1.0f;
 
         return;
@@ -2291,7 +2294,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
     LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
 
     for (size_t i = 0; i < size_org; ++i) {
-        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+        const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
 
         if (cur_p->data[i].p < thold && !is_eog) {
             continue;
@@ -2314,7 +2317,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
 
 static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
-    return llama_sampler_init_infill_impl(*ctx->vocab);
+    return llama_sampler_init_infill(ctx->vocab);
 }
 
 static void llama_sampler_infill_free(struct llama_sampler * smpl) {
@@ -2330,14 +2333,13 @@ static struct llama_sampler_i llama_sampler_infill_i = {
     /* .free   = */ llama_sampler_infill_free,
 };
 
-struct llama_sampler * llama_sampler_init_infill_impl(
-        const struct llama_vocab & vocab) {
+struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
     return new llama_sampler {
         /* .iface = */ &llama_sampler_infill_i,
         /* .ctx   = */ new llama_sampler_infill {
-            /* .vocab = */ &vocab,
-            /* .buf0 = */ std::vector<char>(512),
-            /* .buf1 = */ std::vector<char>(512),
+            /* .vocab = */ vocab,
+            /* .buf0  = */ std::vector<char>(512),
+            /* .buf1  = */ std::vector<char>(512),
         },
     };
 }
diff --git a/examples/talk-llama/llama-sampling.h b/examples/talk-llama/llama-sampling.h
index 919f6fdfcef..759dd7dcb70 100644
--- a/examples/talk-llama/llama-sampling.h
+++ b/examples/talk-llama/llama-sampling.h
@@ -2,7 +2,9 @@
 
 // TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
 
-#include "llama-grammar.h"
+#include "llama.h"
+
+#include <vector>
 
 struct llama_vocab;
 struct llama_grammar;
@@ -21,24 +23,6 @@ struct llama_sampler_chain {
     mutable int32_t n_sample;
 };
 
-struct llama_sampler * llama_sampler_init_grammar_impl(
-        const struct llama_vocab & vocab,
-                      const char * grammar_str,
-                      const char * grammar_root);
-
-struct llama_sampler * llama_sampler_init_infill_impl(
-        const struct llama_vocab & vocab);
-
-struct llama_sampler * llama_sampler_init_dry_impl(
-        const struct llama_vocab &  vocab,
-                         int32_t    context_size,
-                           float    dry_multiplier,
-                           float    dry_base,
-                         int32_t    dry_allowed_length,
-                         int32_t    dry_penalty_last_n,
-                      const char ** seq_breakers,
-                          size_t    num_breakers);
-
 struct llama_sampler * llama_sampler_init_dry_testing(
                          int32_t   context_size,
                            float   dry_multiplier,
diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp
index a4c015484dd..96b74e93a51 100644
--- a/examples/talk-llama/llama-vocab.cpp
+++ b/examples/talk-llama/llama-vocab.cpp
@@ -1,6 +1,7 @@
 #include "llama-vocab.h"
 
 #include "llama-impl.h"
+#include "llama-model-loader.h"
 
 #include "unicode.h"
 
@@ -11,8 +12,10 @@
 #include 
 #include 
 #include 
+#include 
 #include 
-#include 
+#include 
+#include 
 
 //
 // helpers
@@ -62,96 +65,14 @@ struct naive_trie {
 };
 
 //
-// impl
+// tokenizers
 //
 
 struct llm_tokenizer {
-   llm_tokenizer() {}
-   virtual ~llm_tokenizer() = default;
+    llm_tokenizer() {}
+    virtual ~llm_tokenizer() = default;
 };
 
-llama_vocab::~llama_vocab() {
-    delete tokenizer;
-}
-
-int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
-    GGML_ASSERT(token_left.find(' ')   == std::string::npos);
-    GGML_ASSERT(token_left.find('\n')  == std::string::npos);
-    GGML_ASSERT(token_right.find(' ')  == std::string::npos);
-    GGML_ASSERT(token_right.find('\n') == std::string::npos);
-
-    auto it = bpe_ranks.find(std::make_pair(token_left, token_right));
-    if (it == bpe_ranks.end()) {
-        return -1;
-    }
-
-    return it->second;
-}
-
-static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
-    return vocab.type;
-}
-
-static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
-}
-
-static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
-}
-
-static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
-}
-
-static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
-}
-
-static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
-}
-
-static bool llama_is_unused_token(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED;
-}
-
-static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) {
-    GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
-    GGML_ASSERT(llama_is_byte_token(vocab, id));
-    const auto & token_data = vocab.id_to_token.at(id);
-    switch (llama_vocab_get_type(vocab)) {
-        case LLAMA_VOCAB_TYPE_SPM:
-        case LLAMA_VOCAB_TYPE_UGM: {
-            auto buf = token_data.text.substr(3, 2);
-            return strtol(buf.c_str(), NULL, 16);
-        }
-        case LLAMA_VOCAB_TYPE_BPE: {
-            GGML_ABORT("fatal error");
-            //return unicode_utf8_to_byte(token_data.text); // TODO: why is this here after GGML_ASSERT?
-        }
-        case LLAMA_VOCAB_TYPE_WPM: {
-            GGML_ABORT("fatal error");
-        }
-        default:
-            GGML_ABORT("fatal error");
-    }
-}
-
-static void llama_escape_whitespace(std::string & text) {
-    replace_all(text, " ", "\xe2\x96\x81");
-}
-
-static void llama_unescape_whitespace(std::string & word) {
-    replace_all(word, "\xe2\x96\x81", " ");
-}
-
 struct llm_symbol {
     using index = int;
     index prev;
@@ -183,14 +104,13 @@ struct llm_bigram_spm {
 };
 
 struct llm_tokenizer_spm : llm_tokenizer {
-    llm_tokenizer_spm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
+    llm_tokenizer_spm(const llama_vocab & /*vocab*/) {}
 };
 
 struct llm_tokenizer_spm_session {
     llm_tokenizer_spm_session(const llama_vocab & vocab) : vocab(vocab) {}
 
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
-
+    void tokenize(const std::string & text, std::vector<llama_token> & output) {
         // split string into utf8 chars
         int index = 0;
         size_t offs = 0;
@@ -249,13 +169,13 @@ struct llm_tokenizer_spm_session {
     }
 
 private:
-    void resegment(llm_symbol & symbol, std::vector<llama_vocab::id> & output) {
+    void resegment(llm_symbol & symbol, std::vector<llama_token> & output) {
         auto text = std::string(symbol.text, symbol.n);
-        auto token = vocab.token_to_id.find(text);
+        auto token = vocab.text_to_token(text);
 
         // Do we need to support is_unused?
-        if (token != vocab.token_to_id.end()) {
-            output.push_back((*token).second);
+        if (token != LLAMA_TOKEN_NULL) {
+            output.push_back(token);
             return;
         }
 
@@ -265,8 +185,8 @@ struct llm_tokenizer_spm_session {
             // output any symbols that did not form tokens as bytes.
             output.reserve(output.size() + symbol.n);
             for (int j = 0; j < (int)symbol.n; ++j) {
-                llama_vocab::id token_id = llama_byte_to_token_impl(vocab, symbol.text[j]);
-                output.push_back(token_id);
+                llama_token id = vocab.byte_to_token(symbol.text[j]);
+                output.push_back(id);
             }
             return;
         }
@@ -280,17 +200,17 @@ struct llm_tokenizer_spm_session {
             return;
         }
         const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
-        auto token = vocab.token_to_id.find(text);
+        auto token = vocab.text_to_token(text);
 
-        if (token == vocab.token_to_id.end()) {
+        if (token == LLAMA_TOKEN_NULL) {
             return;
         }
 
-        if (static_cast<size_t>((*token).second) >= vocab.id_to_token.size()) {
+        if (static_cast<uint32_t>(token) >= vocab.n_tokens()) {
             return;
         }
 
-        const auto & tok_data = vocab.id_to_token[(*token).second];
+        const auto & tok_data = vocab.get_token_data(token);
 
         llm_bigram_spm bigram;
         bigram.left  = left;
@@ -353,9 +273,9 @@ struct llm_bigram_bpe {
 };
 
 struct llm_tokenizer_bpe : llm_tokenizer {
-    llm_tokenizer_bpe(const llama_vocab & vocab) : llm_tokenizer() {
-        GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
-        switch (vocab.type_pre) {
+    llm_tokenizer_bpe(const llama_vocab & vocab) {
+        GGML_ASSERT(vocab.get_type() == LLAMA_VOCAB_TYPE_BPE);
+        switch (vocab.get_pre_type()) {
             case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
                 regex_exprs = {
                     // original regex from tokenizer.json
@@ -488,39 +408,38 @@ struct llm_tokenizer_bpe : llm_tokenizer {
 };
 
 struct llm_tokenizer_bpe_session {
-    llm_tokenizer_bpe_session(const llama_vocab & vocab) : vocab(vocab),
-        bpe_tokenizer(static_cast<const llm_tokenizer_bpe *>(vocab.tokenizer)) {}
+    llm_tokenizer_bpe_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : vocab(vocab), tokenizer(tokenizer) {}
 
-    static void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output)  {
+    static void append(const llama_token token_id, std::vector<llama_token> & output)  {
         output.push_back(token_id);
     }
 
-    bool append_bos(std::vector<llama_vocab::id> & output) const {
-        if (vocab.tokenizer_add_bos) {
-            GGML_ASSERT(vocab.special_bos_id != LLAMA_TOKEN_NULL);
-            output.push_back(vocab.special_bos_id);
+    bool append_bos(std::vector<llama_token> & output) const {
+        if (vocab.get_add_bos()) {
+            GGML_ASSERT(vocab.token_bos() != LLAMA_TOKEN_NULL);
+            output.push_back(vocab.token_bos());
             return true;
         }
         return false;
     }
 
-    bool append_eos(std::vector<llama_vocab::id> & output) const {
-        if (vocab.tokenizer_add_eos) {
-            GGML_ASSERT(vocab.special_eos_id != LLAMA_TOKEN_NULL);
-            output.push_back(vocab.special_eos_id);
+    bool append_eos(std::vector<llama_token> & output) const {
+        if (vocab.get_add_eos()) {
+            GGML_ASSERT(vocab.token_eos() != LLAMA_TOKEN_NULL);
+            output.push_back(vocab.token_eos());
             return true;
         }
         return false;
     }
 
-    void check_double_bos_eos(const std::vector<llama_vocab::id> & output) const {
-        if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+    void check_double_bos_eos(const std::vector<llama_token> & output) const {
+        if (vocab.get_add_bos() && output.size() >= 2 && output[1] == vocab.token_bos()) {
             LLAMA_LOG_WARN(
                 "%s: Added a BOS token to the prompt as specified by the model but the prompt "
                 "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
                 "Are you sure this is what you want?\n", __FUNCTION__);
         }
-        if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) {
+        if (vocab.get_add_eos() && output.size() >= 2 && *(output.end()-2) == vocab.token_eos()) {
             LLAMA_LOG_WARN(
                 "%s: Added a EOS token to the prompt as specified by the model but the prompt "
                 "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. "
@@ -528,9 +447,9 @@ struct llm_tokenizer_bpe_session {
         }
     }
 
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+    void tokenize(const std::string & text, std::vector<llama_token> & output) {
         int final_prev_index = -1;
-        const auto word_collection = unicode_regex_split(text, bpe_tokenizer->regex_exprs);
+        const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs);
 
         symbols_final.clear();
 
@@ -541,7 +460,8 @@ struct llm_tokenizer_bpe_session {
             int index = 0;
             size_t offset = 0;
 
-            if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
+            //if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
+            if (vocab.get_ignore_merges() && vocab.text_to_token(word) != LLAMA_TOKEN_NULL) {
                 symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
                 offset = word.size();
             }
@@ -615,18 +535,18 @@ struct llm_tokenizer_bpe_session {
                 }
 
                 const std::string str = std::string(symbol.text, symbol.n);
-                const auto token = vocab.token_to_id.find(str);
+                const auto token = vocab.text_to_token(str);
 
-                if (token == vocab.token_to_id.end()) {
+                if (token == LLAMA_TOKEN_NULL) {
                     for (auto j = str.begin(); j != str.end(); ++j) {
                         std::string byte_str(1, *j);
-                        auto token_multibyte = vocab.token_to_id.find(byte_str);
-                        if (token_multibyte != vocab.token_to_id.end()) {
-                            output.push_back(token_multibyte->second);
+                        auto token_multibyte = vocab.text_to_token(byte_str);
+                        if (token_multibyte != LLAMA_TOKEN_NULL) {
+                            output.push_back(token_multibyte);
                         }
                     }
                 } else {
-                    output.push_back((*token).second);
+                    output.push_back(token);
                 }
             }
         }
@@ -660,7 +580,7 @@ struct llm_tokenizer_bpe_session {
     }
 
     const llama_vocab & vocab;
-    const llm_tokenizer_bpe * bpe_tokenizer;
+    const llm_tokenizer_bpe & tokenizer;
 
     std::vector<llm_symbol> symbols;
     std::vector<llm_symbol> symbols_final;
@@ -672,14 +592,13 @@ struct llm_tokenizer_bpe_session {
 //
 
 struct llm_tokenizer_wpm : llm_tokenizer {
-    llm_tokenizer_wpm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
+    llm_tokenizer_wpm(const llama_vocab & /*vocab*/) {}
 };
 
 struct llm_tokenizer_wpm_session {
     llm_tokenizer_wpm_session(const llama_vocab & vocab) : vocab(vocab) {}
 
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
-        const auto & token_map = vocab.token_to_id;
+    void tokenize(const std::string & text, std::vector<llama_token> & output) {
         // normalize and split by whitespace
         std::vector<std::string> words = preprocess(text);
         // bos token prepended already
@@ -702,10 +621,10 @@ struct llm_tokenizer_wpm_session {
             for (int i = 0; i < n; ++i) {
                 // loop through possible match length
                 bool match = false;
-                for (int j = std::min(n, i + vocab.max_token_len + 1); j > i; j--) {
-                    auto it = token_map.find(word1.substr(i, j - i));
-                    if (it != token_map.end()) {
-                        output.push_back(it->second);
+                for (int j = std::min(n, i + vocab.max_token_len() + 1); j > i; j--) {
+                    auto id = vocab.text_to_token(word1.substr(i, j - i));
+                    if (id != LLAMA_TOKEN_NULL) {
+                        output.push_back(id);
                         match = true;
                         i = j - 1;
                         break;
@@ -720,7 +639,7 @@ struct llm_tokenizer_wpm_session {
 
             // we didn't find any matches for this word
             if (current_tokens == output.size()) {
-                output.push_back(vocab.special_unk_id);
+                output.push_back(vocab.token_unk());
             }
         }
     }
@@ -789,45 +708,45 @@ struct llm_tokenizer_wpm_session {
 //
 
 struct llm_tokenizer_ugm : llm_tokenizer {
-    llm_tokenizer_ugm(const llama_vocab & vocab) : llm_tokenizer() {
-        if (vocab.precompiled_charsmap.size() > 0) {
+    llm_tokenizer_ugm(const llama_vocab & vocab, const std::vector<char> & precompiled_charsmap) {
+        if (precompiled_charsmap.size() > 0) {
             size_t charsmap_offset = 0;
 
             // First four bytes of precompiled_charsmap contains length of binary
             // blob containing XOR-compressed compact double array (XCDA) entries
-            uint32_t xcda_blob_size = *(const uint32_t *) &vocab.precompiled_charsmap[0];
+            uint32_t xcda_blob_size = *(const uint32_t *) &precompiled_charsmap[0];
             charsmap_offset += sizeof(xcda_blob_size);
-            if (xcda_blob_size + charsmap_offset >= vocab.precompiled_charsmap.size()) {
+            if (xcda_blob_size + charsmap_offset >= precompiled_charsmap.size()) {
                 throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
             }
 
             // Next xcda_blob_size bytes contain entries of XOR-compressed compact
             // double array (XCDA). Each entry is bit-packed into a 32-bit integer.
-            xcda_array = (const uint32_t *) &vocab.precompiled_charsmap[charsmap_offset];
+            xcda_array = (const uint32_t *) &precompiled_charsmap[charsmap_offset];
             xcda_array_size = xcda_blob_size / sizeof(uint32_t);
             charsmap_offset += xcda_blob_size;
 
             // Remaining bytes of precompiled charsmap contain null-terminated
             // replacement strings for prefixes matched by the XCDA.
-            prefix_replacements = &vocab.precompiled_charsmap[charsmap_offset];
-            prefix_replacements_size = vocab.precompiled_charsmap.size() - charsmap_offset;
+            prefix_replacements = &precompiled_charsmap[charsmap_offset];
+            prefix_replacements_size = precompiled_charsmap.size() - charsmap_offset;
         }
 
-        for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
-            const auto &token_data = vocab.id_to_token[id];
+        for (uint32_t id = 0; id < vocab.n_tokens(); ++id) {
+            const auto & token_data = vocab.get_token_data(id);
 
-            if (llama_is_normal_token(vocab, id)) {
+            if (vocab.is_normal(id)) {
                 min_score = std::min(min_score, token_data.score);
                 max_score = std::max(max_score, token_data.score);
             }
 
-            if (llama_is_normal_token(vocab, id) ||
-                llama_is_user_defined_token(vocab, id) ||
-                llama_is_unused_token(vocab, id)) {
+            if (vocab.is_normal(id) ||
+                vocab.is_user_defined(id) ||
+                vocab.is_unused(id)) {
                 token_matcher.insert(token_data.text.data(), token_data.text.size(), id);
             }
 
-            if (llama_is_user_defined_token(vocab, id)) {
+            if (vocab.is_user_defined(id)) {
                 user_defined_token_matcher.insert(token_data.text.data(), token_data.text.size());
             }
         }
@@ -856,8 +775,7 @@ struct llm_tokenizer_ugm : llm_tokenizer {
 };
 
 struct llm_tokenizer_ugm_session {
-    llm_tokenizer_ugm_session(const llama_vocab & vocab) : vocab(vocab),
-        ugm_tokenizer(static_cast<const llm_tokenizer_ugm *>(vocab.tokenizer)) {}
+    llm_tokenizer_ugm_session(const llama_vocab & vocab, const llm_tokenizer_ugm & tokenizer) : vocab(vocab), tokenizer(tokenizer) {}
 
     /* This implementation is based on SentencePiece optimized Viterbi algorithm for
      * unigram language models. The general idea is to:
@@ -872,7 +790,7 @@ struct llm_tokenizer_ugm_session {
      * After processing the whole sequence we backtrack from the end to get
      * the best tokenization.
     */
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+    void tokenize(const std::string & text, std::vector<llama_token> & output) {
         // get current size of output (for reversal later)
         size_t output_size = output.size();
 
@@ -885,9 +803,9 @@ struct llm_tokenizer_ugm_session {
         }
 
         // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
-        std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.special_unk_id, 0, -FLT_MAX});
+        std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -FLT_MAX});
         // at the beginning tokenization score is zero
-        tokenization_results[0] = { vocab.special_unk_id, 0, 0 };
+        tokenization_results[0] = { vocab.token_unk(), 0, 0 };
 
         for (size_t input_offset = 0; input_offset < input_len;) {
             size_t prefix_offset = input_offset;
@@ -897,7 +815,7 @@ struct llm_tokenizer_ugm_session {
             // traverse the token matcher trie to find a matching token
             bool single_codepoint_token_found = false;
             const struct best_tokenization & current_best = tokenization_results[input_offset];
-            const struct naive_trie * node = ugm_tokenizer->token_matcher.traverse(normalized[prefix_offset++]);
+            const struct naive_trie * node = tokenizer.token_matcher.traverse(normalized[prefix_offset++]);
 
             while (prefix_offset <= input_len && node != NULL) {
                 // check if we found valid token in prefix
@@ -907,13 +825,13 @@ struct llm_tokenizer_ugm_session {
                         single_codepoint_token_found = true;
                     }
                     llama_token token_id = node->value;
-                    const auto & token_data = vocab.id_to_token[token_id];
+                    const auto & token_data = vocab.get_token_data(token_id);
 
                     // we set the user-defined token scores to 0 to make them more likely to be selected
                     // (normal token scores are log probabilities, so they are negative)
                     // score type is double here to make tokenization results exactly
                     // the same as in the HF tokenizer using SentencePiece
-                    const double token_score = llama_is_user_defined_token(vocab, token_id) ? 0.0 : token_data.score;
+                    const double token_score = vocab.is_user_defined(token_id) ? 0.0 : token_data.score;
                     const double challenger_score = current_best.score_sum + token_score;
                     struct best_tokenization & current_champ = tokenization_results[prefix_offset];
                     if (challenger_score > current_champ.score_sum) {
@@ -927,11 +845,11 @@ struct llm_tokenizer_ugm_session {
             // if we didn't find a valid token corresponding to the whole UTF code point
             // then use unknown token as the tokenization of this UTF code point
             if (!single_codepoint_token_found) {
-                const double challenger_score = current_best.score_sum + ugm_tokenizer->unknown_token_score;
+                const double challenger_score = current_best.score_sum + tokenizer.unknown_token_score;
                 prefix_offset = input_offset + n_utf8_code_units;
                 struct best_tokenization & current_champ = tokenization_results[prefix_offset];
                 if (challenger_score > current_champ.score_sum) {
-                    struct best_tokenization challenger = { vocab.special_unk_id, input_offset, (float) challenger_score };
+                    struct best_tokenization challenger = { vocab.token_unk(), input_offset, (float) challenger_score };
                     current_champ = challenger;
                 }
             }
@@ -944,7 +862,7 @@ struct llm_tokenizer_ugm_session {
         // merge sequences of consecutive unknown tokens into single unknown tokens
         bool is_prev_unknown = false;
         for (struct best_tokenization & tokenization = tokenization_results[input_len]; ; tokenization = tokenization_results[tokenization.input_offset]) {
-            bool is_unknown = tokenization.token_id == vocab.special_unk_id;
+            bool is_unknown = tokenization.token_id == vocab.token_unk();
             if (!(is_prev_unknown && is_unknown)) {
                 output.push_back(tokenization.token_id);
             }
@@ -971,11 +889,11 @@ struct llm_tokenizer_ugm_session {
         normalized->clear();
         normalized->reserve(input.size() * 3);
 
-        const std::string space = vocab.tokenizer_escape_whitespaces ? ugm_tokenizer->escaped_space : " ";
+        const std::string space = vocab.get_escape_whitespaces() ? tokenizer.escaped_space : " ";
 
-        bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
-        bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
-        bool shall_merge_spaces = vocab.tokenizer_remove_extra_whitespaces;
+        const bool shall_prepend_space = !vocab.get_treat_whitespace_as_suffix() && vocab.get_add_space_prefix();
+        const bool shall_append_space  =  vocab.get_treat_whitespace_as_suffix() && vocab.get_add_space_prefix();
+        const bool shall_merge_spaces  =  vocab.get_remove_extra_whitespaces();
 
         bool is_space_prepended = false;
         bool processing_non_ws = false;
@@ -1067,7 +985,7 @@ struct llm_tokenizer_ugm_session {
 
         // if input prefix matches some user-defined token return this token as normalization result
         auto user_defined_token_match =
-           ugm_tokenizer->user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
+           tokenizer.user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
         if (user_defined_token_match.second > 0) {
             return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second };
         }
@@ -1075,8 +993,8 @@ struct llm_tokenizer_ugm_session {
         size_t longest_prefix_length = 0;
         size_t longest_prefix_offset = 0;
 
-        if (ugm_tokenizer->xcda_array_size > 0) {
-            struct xcda_array_view xcda_view(ugm_tokenizer->xcda_array, ugm_tokenizer->xcda_array_size);
+        if (tokenizer.xcda_array_size > 0) {
+            struct xcda_array_view xcda_view(tokenizer.xcda_array, tokenizer.xcda_array_size);
 
             // Find the longest normalized sequence matching the input prefix by walking
             // the XOR-compressed compact double array (XCDA) starting from the root node
@@ -1112,10 +1030,10 @@ struct llm_tokenizer_ugm_session {
 
         if (longest_prefix_length > 0) {
             // we have a match, so return the replacement sequence
-            if (longest_prefix_offset >= ugm_tokenizer->prefix_replacements_size) {
+            if (longest_prefix_offset >= tokenizer.prefix_replacements_size) {
                 throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
             }
-            const char * prefix_replacement = &(ugm_tokenizer->prefix_replacements)[longest_prefix_offset];
+            const char * prefix_replacement = &(tokenizer.prefix_replacements)[longest_prefix_offset];
             return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length };
         }
 
@@ -1132,7 +1050,7 @@ struct llm_tokenizer_ugm_session {
     }
 
     const llama_vocab & vocab;
-    const llm_tokenizer_ugm * ugm_tokenizer;
+    const llm_tokenizer_ugm & tokenizer;
 };
 
 //
@@ -1194,15 +1112,15 @@ static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escape
 }
 
 struct llm_tokenizer_rwkv : llm_tokenizer {
-    llm_tokenizer_rwkv(const llama_vocab & vocab) : llm_tokenizer() {
+    llm_tokenizer_rwkv(const llama_vocab & vocab) {
         // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
         // For now, we decode the vocab here into the lookup we'll use for tokenization.
 
         // build trie
-        for (unsigned int id = 0; id < vocab.id_to_token.size(); ++id) {
-            const auto & token = vocab.id_to_token[id];
-            const auto data = llama_unescape_rwkv_token(token.text);
-            token_matcher.insert((const char *) data.data(), data.size(), id);
+        for (uint32_t id = 0; id < vocab.n_tokens(); ++id) {
+            const auto & data = vocab.get_token_data(id);
+            const auto text = llama_unescape_rwkv_token(data.text);
+            token_matcher.insert((const char *) text.data(), text.size(), id);
         }
     }
 
@@ -1210,16 +1128,15 @@ struct llm_tokenizer_rwkv : llm_tokenizer {
 };
 
 struct llm_tokenizer_rwkv_session {
-    llm_tokenizer_rwkv_session(const llama_vocab & vocab) : vocab(vocab),
-        rwkv_tokenizer(static_cast<const llm_tokenizer_rwkv &>(*vocab.tokenizer)) {}
+    llm_tokenizer_rwkv_session(const llama_vocab & vocab, const llm_tokenizer_rwkv & tokenizer) : vocab(vocab), tokenizer(tokenizer) {}
 
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+    void tokenize(const std::string & text, std::vector<llama_token> & output) {
         uint32_t position = 0;
         while (position < text.size()) {
-            const struct naive_trie * node = rwkv_tokenizer.token_matcher.traverse(text[position]);
+            const struct naive_trie * node = tokenizer.token_matcher.traverse(text[position]);
             if (node == NULL) {
                 // no matching token found, add unknown token
-                output.push_back(vocab.special_unk_id);
+                output.push_back(vocab.token_unk());
                 position += 1;
                 continue;
             }
@@ -1243,33 +1160,11 @@ struct llm_tokenizer_rwkv_session {
 
 private:
     const llama_vocab & vocab;
-    const llm_tokenizer_rwkv & rwkv_tokenizer;
+    const llm_tokenizer_rwkv & tokenizer;
 };
 
-void llama_vocab::init_tokenizer() {
-    switch (type) {
-        case LLAMA_VOCAB_TYPE_SPM:
-            tokenizer = new llm_tokenizer_spm(*this);
-            break;
-        case LLAMA_VOCAB_TYPE_BPE:
-            tokenizer = new llm_tokenizer_bpe(*this);
-            break;
-        case LLAMA_VOCAB_TYPE_WPM:
-            tokenizer = new llm_tokenizer_wpm(*this);
-            break;
-        case LLAMA_VOCAB_TYPE_UGM:
-            tokenizer = new llm_tokenizer_ugm(*this);
-            break;
-        case LLAMA_VOCAB_TYPE_RWKV:
-            tokenizer = new llm_tokenizer_rwkv(*this);
-            break;
-        default:
-            GGML_ABORT("unsupported vocab type");
-    }
-}
-
 //
-// (de-) tokenize
+// impl
 //
 
 typedef enum FRAGMENT_BUFFER_VARIANT_TYPE {
@@ -1278,7 +1173,7 @@ typedef enum FRAGMENT_BUFFER_VARIANT_TYPE {
 } FRAGMENT_BUFFER_VARIANT_TYPE;
 
 struct fragment_buffer_variant {
-    fragment_buffer_variant(llama_vocab::id _token)
+    fragment_buffer_variant(llama_token _token)
     :
         type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN),
         token(_token),
@@ -1289,7 +1184,7 @@ struct fragment_buffer_variant {
     fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length)
     :
         type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT),
-        token((llama_vocab::id) - 1),
+        token((llama_token) - 1),
         raw_text(_raw_text),
         offset(_offset),
         length(_length){
@@ -1299,451 +1194,1094 @@ struct fragment_buffer_variant {
         }
 
     const FRAGMENT_BUFFER_VARIANT_TYPE type;
-    const llama_vocab::id token;
+    const llama_token token;
     const std::string _dummy;
     const std::string & raw_text;
     const uint64_t offset;
     const uint64_t length;
 };
 
-// #define PRETOKENIZERDEBUG
+struct llama_vocab::impl {
+    uint32_t n_token_types = 0; // for BERT-style token types
 
-static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer, bool parse_special) {
-    // for each special token
-    for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
-        const auto & data = vocab.id_to_token[special_id];
-        const auto & special_token = data.text;
+    enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
+    enum llama_vocab_pre_type pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
 
-        if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) {
-            // Ignore control and unknown tokens when parse_special == false
-            continue;
-            // User-defined tokens are still pre-tokenized before everything else
-            // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726
-            // This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.)
-        }
+    int max_token_len = 0; // used for optimizing longest token search
 
-        // for each text fragment
-        std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
-        while (it != buffer.end()) {
-            auto & fragment = (*it);
+    // default LLaMA special tokens
+    // TODO: should we set all of these to LLAMA_TOKEN_NULL?
+    llama_token special_bos_id  = 1;
+    llama_token special_eos_id  = 2;
+    llama_token special_eot_id  = LLAMA_TOKEN_NULL;
+    llama_token special_eom_id  = LLAMA_TOKEN_NULL;
+    llama_token special_unk_id  = 0;
+    llama_token special_sep_id  = LLAMA_TOKEN_NULL;
+    llama_token special_pad_id  = LLAMA_TOKEN_NULL;
+    llama_token special_mask_id = LLAMA_TOKEN_NULL;
 
-            // if a fragment is text ( not yet processed )
-            if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                const auto & raw_text = fragment.raw_text;
+    llama_token linefeed_id = 13;
 
-                auto raw_text_base_offset = fragment.offset;
-                auto raw_text_base_length = fragment.length;
+    // fim tokens
+    llama_token special_fim_pre_id = LLAMA_TOKEN_NULL;
+    llama_token special_fim_suf_id = LLAMA_TOKEN_NULL;
+    llama_token special_fim_mid_id = LLAMA_TOKEN_NULL;
+    llama_token special_fim_pad_id = LLAMA_TOKEN_NULL;
+    llama_token special_fim_rep_id = LLAMA_TOKEN_NULL; // repo
+    llama_token special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator
 
-                // loop over the text
-                while (true) {
-                    // find the first occurrence of a given special token in this fragment
-                    //  passing offset argument only limit the "search area" but match coordinates
-                    //  are still relative to the source full raw_text
-                    auto match = raw_text.find(special_token, raw_text_base_offset);
+    // tokenizer flags
+    bool add_space_prefix           = false;
+    bool add_bos                    = false;
+    bool add_eos                    = false;
+    bool ignore_merges              = false;
+    bool clean_spaces               = false;  // clean_up_tokenization_spaces
+    bool remove_extra_whitespaces   = false;
+    bool escape_whitespaces         = true;
+    bool treat_whitespace_as_suffix = false;
 
-                    // no occurrences found, stop processing this fragment for a given special token
-                    if (match == std::string::npos) break;
+    std::unordered_map<std::string, llama_token> token_to_id;
+    std::vector<token_data>                      id_to_token;
 
-                    // check if match is within bounds of offset <-> length
-                    if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break;
+    std::vector<llama_token> cache_special_tokens;
+    std::vector<std::string> cache_token_to_piece; // llama_token_to_piece(special = true);
 
-#ifdef PRETOKENIZERDEBUG
-                    LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
-#endif
-                    auto source = std::distance(buffer.begin(), it);
+    std::map<std::pair<std::string, std::string>, int> bpe_ranks;
 
-                    // if match is further than base offset
-                    //  then we have some text to the left of it
-                    if (match > raw_text_base_offset) {
-                        // left
-                        const int64_t left_reminder_offset = raw_text_base_offset + 0;
-                        int64_t left_reminder_length = match - raw_text_base_offset;
+    // set of all tokens that cause "end of generation"
+    std::set<llama_token> special_eog_ids;
 
-                        if (data.attr & LLAMA_TOKEN_ATTR_LSTRIP) {
-                            while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) {
-                                left_reminder_length--;
-                            }
-                        }
+    std::unique_ptr<llm_tokenizer> tokenizer;
 
-                        if (left_reminder_length > 0) {
-                            buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
-                            it++;
-                        }
+    std::vector<char> precompiled_charsmap;
 
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
-#endif
-                    }
+    impl(const llama_vocab & vocab) : vocab(vocab) {
+    }
 
-                    // special token
-                    buffer.emplace_after(it, special_id);
-                    it++;
+    ~impl() = default;
 
-                    // right
-                    if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
-                        int64_t right_reminder_offset = match + special_token.length();
-                        int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
+    void load(llama_model_loader & ml, const LLM_KV & kv);
 
-                        if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) {
-                            while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
-                                right_reminder_offset++;
-                                right_reminder_length--;
-                            }
-                        }
+    enum llama_vocab_type get_type() const;
 
-                        if (right_reminder_length > 0) {
-                            buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
-                            it++;
-                        }
+    std::string type_name() const;
 
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
-#endif
+    bool is_normal      (llama_token id) const;
+    bool is_unknown     (llama_token id) const;
+    bool is_control     (llama_token id) const;
+    bool is_byte        (llama_token id) const;
+    bool is_user_defined(llama_token id) const;
+    bool is_unused      (llama_token id) const;
+    bool is_eog         (llama_token id) const;
 
-                        if (source == 0) {
-                            buffer.erase_after(buffer.before_begin());
-                        } else {
-                            buffer.erase_after(std::next(buffer.begin(), (source - 1)));
-                        }
+    uint8_t token_to_byte(llama_token id) const;
 
-                        // repeat for the right side
-                        raw_text_base_offset = right_reminder_offset;
-                        raw_text_base_length = right_reminder_length;
+    llama_token_attr token_get_attr(llama_token id) const;
 
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
-#endif
-                    } else {
-                        if (source == 0) {
-                            buffer.erase_after(buffer.before_begin());
-                        } else {
-                            buffer.erase_after(std::next(buffer.begin(), (source - 1)));
-                        }
-                        break;
-                    }
-                }
-            }
-            it++;
-        }
-    }
-}
+    void init_tokenizer(enum llama_vocab_type type);
 
-std::vector<llama_token> llama_tokenize_internal(
-        const llama_vocab & vocab,
-        std::string raw_text,
-        bool add_special,
-        bool parse_special) {
-    GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+    void tokenizer_st_partition(std::forward_list<fragment_buffer_variant> & buffer, bool parse_special) const;
 
-    std::vector<llama_token> output;
-    std::forward_list<fragment_buffer_variant> fragment_buffer;
+    std::string token_to_piece_for_cache(
+                  llama_token   token,
+                         bool   special) const;
 
-    if (!raw_text.empty()) {
-        fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
-        tokenizer_st_partition(vocab, fragment_buffer, parse_special);
-    }
 
-    switch (vocab.type) {
-        case LLAMA_VOCAB_TYPE_SPM:
-            {
-                // OG tokenizer behavior:
-                //
-                // tokenizer.encode('', add_special_tokens=True)  returns [1]
-                // tokenizer.encode('', add_special_tokens=False) returns []
+    std::vector<llama_token> tokenize(
+            const std::string & raw_text,
+                         bool   add_special,
+                         bool   parse_special = false) const;
 
-                bool is_prev_special = true;  // prefix with space if first token
+    int32_t tokenize(
+                   const char * text,
+                      int32_t   text_len,
+                  llama_token * tokens,
+                      int32_t   n_tokens_max,
+                         bool   add_special,
+                         bool   parse_special) const;
 
-                if (add_special && vocab.tokenizer_add_bos) {
-                    GGML_ASSERT(vocab.special_bos_id != LLAMA_TOKEN_NULL);
-                    output.push_back(vocab.special_bos_id);
-                    is_prev_special = true;
-                }
+    // does not write null-terminator to buf
+    int32_t token_to_piece(
+                  llama_token   token,
+                         char * buf,
+                      int32_t   length,
+                      int32_t   lstrip,
+                         bool   special) const;
 
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+    // use cached data
+    const std::string & token_to_piece(llama_token token) const;
 
-                        // prefix with space if previous is special
-                        if (vocab.tokenizer_add_space_prefix && is_prev_special) {
-                            raw_text = " " + raw_text;
-                        }
+    int32_t detokenize(
+            const llama_token * tokens,
+                      int32_t   n_tokens,
+                         char * text,
+                      int32_t   text_len_max,
+                         bool   remove_special,
+                         bool   unparse_special) const;
 
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        llama_escape_whitespace(raw_text);
-                        llm_tokenizer_spm_session session(vocab);
-                        session.tokenize(raw_text, output);
-                        is_prev_special = false;
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                        is_prev_special = true;
-                    }
-                }
+    std::string detokenize(
+            const std::vector<llama_token> & tokens,
+                                      bool   special) const;
 
-                if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
-                    LLAMA_LOG_WARN(
-                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
-                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
-                        "Are you sure this is what you want?\n", __FUNCTION__);
-                }
+    void print_info() const;
 
-                if (add_special && vocab.tokenizer_add_eos) {
-                    GGML_ASSERT(vocab.special_eos_id != LLAMA_TOKEN_NULL);
-                    output.push_back(vocab.special_eos_id);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_BPE:
-            {
-                llm_tokenizer_bpe_session session(vocab);
-                // it calls some other methods that are not exist in llm_tokenizer,
-                // here just cast it to bpe tokenizer object
-                if (add_special) {
-                    session.append_bos(output);
-                }
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+private:
+    const llama_vocab & vocab;
+};
 
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        session.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        session.append(fragment.token, output);
-                    }
-                }
+void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
+    struct gguf_context * ctx = ml.meta.get();
 
-                if (add_special) {
-                    session.append_eos(output);
-                    session.check_double_bos_eos(output);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_WPM:
-            {
-                if (add_special) {
-                    GGML_ASSERT(vocab.special_cls_id != LLAMA_TOKEN_NULL);
-                    output.push_back(vocab.special_cls_id);
-                }
+    // determine vocab type
+    {
+        std::string tokenizer_model;
+        std::string tokenizer_pre;
+
+        ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model);
+        ml.get_key(LLM_KV_TOKENIZER_PRE,   tokenizer_pre, false);
+
+        ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, n_token_types, false);
+
+        if (tokenizer_model == "no_vocab" || tokenizer_model == "none") {
+            type = LLAMA_VOCAB_TYPE_NONE;
+
+            // default special tokens
+            special_bos_id  = LLAMA_TOKEN_NULL;
+            special_eos_id  = LLAMA_TOKEN_NULL;
+            special_unk_id  = LLAMA_TOKEN_NULL;
+            special_sep_id  = LLAMA_TOKEN_NULL;
+            special_pad_id  = LLAMA_TOKEN_NULL;
+            special_mask_id = LLAMA_TOKEN_NULL;
+            linefeed_id     = LLAMA_TOKEN_NULL;
+
+            // read vocab size from metadata
+            uint32_t n_tokens = 0;
+            if (!ml.get_key(LLM_KV_VOCAB_SIZE, n_tokens, false)) {
+                LLAMA_LOG_WARN("%s: there is no vocab_size in metadata\n", __func__);
+            }
 
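+            // nothing else to load for a model without a vocabulary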
-                llm_tokenizer_wpm_session session(vocab);
+            return;
+        }
 
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+        if (tokenizer_model == "llama") {
+            type = LLAMA_VOCAB_TYPE_SPM;
+
+            // default special tokens
+            special_bos_id  = 1;
+            special_eos_id  = 2;
+            special_unk_id  = 0;
+            special_sep_id  = LLAMA_TOKEN_NULL;
+            special_pad_id  = LLAMA_TOKEN_NULL;
+            special_mask_id = LLAMA_TOKEN_NULL;
+        } else if (tokenizer_model == "bert") {
+            type = LLAMA_VOCAB_TYPE_WPM;
+
+            // default special tokens
+            special_bos_id  = 101;
+            special_eos_id  = LLAMA_TOKEN_NULL;
+            special_unk_id  = 100;
+            special_sep_id  = 102;
+            special_pad_id  = 0;
+            special_mask_id = 103;
+        } else if (tokenizer_model == "gpt2") {
+            type = LLAMA_VOCAB_TYPE_BPE;
+
+            // read bpe merges and populate bpe ranks
+            const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
+            if (merges_keyidx == -1) {
+                throw std::runtime_error("cannot find tokenizer merges in model file\n");
+            }
 
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        session.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                    }
-                }
+            const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
+            for (int i = 0; i < n_merges; i++) {
+                const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
+                //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
 
-                if (add_special) {
-                    GGML_ASSERT(vocab.special_sep_id != LLAMA_TOKEN_NULL);
-                    output.push_back(vocab.special_sep_id);
-                }
-            } break;
-        case LLAMA_VOCAB_TYPE_UGM:
-            {
-                if (add_special && vocab.tokenizer_add_bos) {
-                    GGML_ASSERT(vocab.special_bos_id != LLAMA_TOKEN_NULL);
-                    output.push_back(vocab.special_bos_id);
-                }
-                llm_tokenizer_ugm_session session(vocab);
+                std::string first;
+                std::string second;
 
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
-                        session.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
-                    }
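+                // each merge entry is serialized as "first second"; the separator is
+                // searched from index 1, presumably so that a merge whose first symbol
+                // is itself a space is not split at position 0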
+                const size_t pos = word.find(' ', 1);
+
+                if (pos != std::string::npos) {
+                    first  = word.substr(0, pos);
+                    second = word.substr(pos + 1);
                 }
 
-                if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
-                    LLAMA_LOG_WARN(
-                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
-                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
-                        "Are you sure this is what you want?\n", __FUNCTION__);
+                bpe_ranks.emplace(std::make_pair(first, second), i);
+            }
+
+            // default special tokens
+            special_bos_id  = 11;
+            special_eos_id  = 11;
+            special_unk_id  = LLAMA_TOKEN_NULL;
+            special_sep_id  = LLAMA_TOKEN_NULL;
+            special_pad_id  = LLAMA_TOKEN_NULL;
+            special_mask_id = LLAMA_TOKEN_NULL;
+        } else if (tokenizer_model == "t5") {
+            type = LLAMA_VOCAB_TYPE_UGM;
+
+            // default special tokens
+            special_bos_id  = LLAMA_TOKEN_NULL;
+            special_eos_id  = 1;
+            special_unk_id  = 2;
+            special_sep_id  = LLAMA_TOKEN_NULL;
+            special_pad_id  = 0;
+            special_mask_id = LLAMA_TOKEN_NULL;
+
+            const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
+            if (precompiled_charsmap_keyidx != -1) {
+                size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
+                const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
+                precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
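+                // this blob is SentencePiece's precompiled normalization data; it is
+                // handed to the UGM tokenizer when init_tokenizer() constructs it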
+#ifdef IS_BIG_ENDIAN
+                // correct endianness of data in precompiled_charsmap binary blob
+                uint32_t * xcda_blob_size = (uint32_t *) &precompiled_charsmap[0];
+                *xcda_blob_size = __builtin_bswap32(*xcda_blob_size);
+                assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap);
+                size_t xcda_array_size = *xcda_blob_size / sizeof(uint32_t);
+                uint32_t * xcda_array = (uint32_t *) &precompiled_charsmap[sizeof(uint32_t)];
+                for (size_t i = 0; i < xcda_array_size; ++i) {
+                    xcda_array[i] = __builtin_bswap32(xcda_array[i]);
                 }
+#endif
+            }
+        } else if (tokenizer_model == "rwkv") {
+            type = LLAMA_VOCAB_TYPE_RWKV;
+
+            // default special tokens
+            special_bos_id = LLAMA_TOKEN_NULL;
+            special_eos_id = LLAMA_TOKEN_NULL;
+            special_unk_id = LLAMA_TOKEN_NULL;
+            special_sep_id = LLAMA_TOKEN_NULL;
+            special_pad_id = LLAMA_TOKEN_NULL;
+        } else {
+            throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
+        }
+
+        // for now, only BPE models have pre-tokenizers
+        if (type == LLAMA_VOCAB_TYPE_BPE) {
+            add_space_prefix = false;
+            clean_spaces = true;
+            if (tokenizer_pre.empty()) {
+                LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
+                LLAMA_LOG_WARN("%s:                                             \n", __func__);
+                LLAMA_LOG_WARN("%s: ************************************        \n", __func__);
+                LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED!        \n", __func__);
+                LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL             \n", __func__);
+                LLAMA_LOG_WARN("%s: ************************************        \n", __func__);
+                LLAMA_LOG_WARN("%s:                                             \n", __func__);
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            } else if (tokenizer_pre == "default") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            } else if (
+                    tokenizer_pre == "llama3"   ||
+                    tokenizer_pre == "llama-v3" ||
+                    tokenizer_pre == "llama-bpe"||
+                    tokenizer_pre == "falcon3") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
+                ignore_merges = true;
+                add_bos = true;
+            } else if (
+                    tokenizer_pre == "deepseek-llm") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "deepseek-coder") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "deepseek-v3") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM;
+                clean_spaces = false;
+            } else if (
+                    tokenizer_pre == "falcon") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON;
+            } else if (
+                    tokenizer_pre == "mpt") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_MPT;
+            } else if (
+                    tokenizer_pre == "starcoder") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_STARCODER;
+            } else if (
+                    tokenizer_pre == "gpt-2"   ||
+                    tokenizer_pre == "phi-2"   ||
+                    tokenizer_pre == "jina-es" ||
+                    tokenizer_pre == "jina-de" ||
+                    tokenizer_pre == "gigachat"   ||
+                    tokenizer_pre == "jina-v1-en" ||
+                    tokenizer_pre == "jina-v2-es" ||
+                    tokenizer_pre == "jina-v2-de" ||
+                    tokenizer_pre == "jina-v2-code" ||
+                    tokenizer_pre == "roberta-bpe") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
+            } else if (
+                    tokenizer_pre == "refact") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT;
+            } else if (
+                tokenizer_pre == "command-r") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_COMMAND_R;
+                clean_spaces = false;
+            } else if (
+                tokenizer_pre == "qwen2") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
+                clean_spaces = false;
+            } else if (
+                tokenizer_pre == "stablelm2") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_STABLELM2;
+            } else if (
+                tokenizer_pre == "olmo") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_OLMO;
+            } else if (
+                tokenizer_pre == "dbrx") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_DBRX;
+            } else if (
+                tokenizer_pre == "smaug-bpe") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_SMAUG;
+            } else if (
+                tokenizer_pre == "poro-chat") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_PORO;
+                clean_spaces = false;
+            } else if (
+                tokenizer_pre == "chatglm-bpe") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
+                special_bos_id = LLAMA_TOKEN_NULL;
+            } else if (
+                tokenizer_pre == "viking") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_VIKING;
+                clean_spaces = false;
+            } else if (
+                tokenizer_pre == "jais") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS;
+            } else if (
+                tokenizer_pre == "tekken") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_TEKKEN;
+                clean_spaces = false;
+                ignore_merges = true;
+                add_bos = true;
+            } else if (
+                tokenizer_pre == "smollm") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_SMOLLM;
+                clean_spaces = false;
+            } else if (
+                tokenizer_pre == "codeshell") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_CODESHELL;
+            } else if (
+                tokenizer_pre == "bloom") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_BLOOM;
+            } else if (
+                tokenizer_pre == "gpt3-finnish") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH;
+            } else if (
+                tokenizer_pre == "exaone") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE;
+            } else if (
+                tokenizer_pre == "chameleon") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON;
+                add_bos = true;
+                clean_spaces = false;
+            } else if (
+                tokenizer_pre == "minerva-7b") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_MINERVA;
+            } else if (
+                tokenizer_pre == "megrez") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
+            } else {
+                throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
+            }
+        } else if (type == LLAMA_VOCAB_TYPE_SPM) {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            add_space_prefix = true;
+            clean_spaces = false;
+            add_bos = true;
+            add_eos = false;
+        } else if (type == LLAMA_VOCAB_TYPE_WPM) {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            add_space_prefix = false;
+            clean_spaces = true;
+            add_bos = true;
+            add_eos = false;
+        } else if (type == LLAMA_VOCAB_TYPE_UGM) {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            add_bos = false;
+            add_eos = true;
+        } else if (type == LLAMA_VOCAB_TYPE_RWKV) {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            add_space_prefix = false;
+            clean_spaces = false;
+            add_bos = false;
+            add_eos = false;
+        } else {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+        }
+
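+        // explicit GGUF metadata, when present, overrides the per-tokenizer defaults above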
+        ml.get_key(LLM_KV_TOKENIZER_ADD_PREFIX,      add_space_prefix,         false);
+        ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, remove_extra_whitespaces, false);
+    }
+
+    const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str());
+    if (token_idx == -1) {
+        throw std::runtime_error("cannot find tokenizer vocab in model file\n");
+    }
+
+    const float * scores = nullptr;
+    const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str());
+    if (score_idx != -1) {
+        scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
+    }
+
+    const int * toktypes = nullptr;
+    const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str());
+    if (toktype_idx != -1) {
+        toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
+    }
+
+    uint32_t n_tokens = gguf_get_arr_n(ctx, token_idx);
+    id_to_token.resize(n_tokens);
+
+    for (uint32_t i = 0; i < n_tokens; i++) {
+        std::string word = gguf_get_arr_str(ctx, token_idx, i);
+        if (word.empty()) {
+            LLAMA_LOG_WARN("%s: empty token at index %u\n", __func__, i);
+            word = "[EMPTY_" + std::to_string(i) + "]";
+        }
+
+        token_to_id[word] = i;
+        max_token_len = std::max(max_token_len, (int) word.size());
+
+        auto & token_data = id_to_token[i];
+        token_data.text  = std::move(word);
+        token_data.score = scores ? scores[i] : 0.0f;
+        token_data.attr  = LLAMA_TOKEN_ATTR_NORMAL;
+
+        if (toktypes) {  //TODO: remove, required until per token attributes are available from GGUF file
+            switch(toktypes[i]) {
+                case LLAMA_TOKEN_TYPE_UNKNOWN:      token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN;      break;
+                case LLAMA_TOKEN_TYPE_UNUSED:       token_data.attr = LLAMA_TOKEN_ATTR_UNUSED;       break;
+                case LLAMA_TOKEN_TYPE_NORMAL:       token_data.attr = LLAMA_TOKEN_ATTR_NORMAL;       break;
+                case LLAMA_TOKEN_TYPE_CONTROL:      token_data.attr = LLAMA_TOKEN_ATTR_CONTROL;      break;
+                case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break;
+                case LLAMA_TOKEN_TYPE_BYTE:         token_data.attr = LLAMA_TOKEN_ATTR_BYTE;         break;
+                case LLAMA_TOKEN_TYPE_UNDEFINED:    token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED;    break;
+                default:                            token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED;    break;
+            }
+        }
+    }
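+    // token_to_id and id_to_token must stay in sync; a mismatch here usually means
+    // the vocab contains duplicate token strings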
+    GGML_ASSERT(id_to_token.size() == token_to_id.size());
+
+    init_tokenizer(type);
+
+    // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
+    if (type == LLAMA_VOCAB_TYPE_SPM) {
+        try {
+            linefeed_id = vocab.byte_to_token('\n');
+        } catch (const std::exception & e) {
+            LLAMA_LOG_WARN("%s: SPM vocabulary, but newline token not found: %s! Using special_pad_id instead.", __func__, e.what());
+            linefeed_id = special_pad_id;
+        }
+    } else if (type == LLAMA_VOCAB_TYPE_WPM) {
+        linefeed_id = special_pad_id;
+    } else if (type == LLAMA_VOCAB_TYPE_RWKV) {
+        const std::vector ids = tokenize("\n", false);
+        GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
+        linefeed_id = ids[0];
+    } else {
+        const std::vector<llama_token> ids = tokenize("\xC4\x8A", false); // U+010A
+
+        //GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
+        if (ids.empty()) {
+            LLAMA_LOG_WARN("%s: model vocab missing newline token, using special_pad_id instead\n", __func__);
+            linefeed_id = special_pad_id;
+        } else {
+            linefeed_id = ids[0];
+        }
+    }
+
+    // special tokens
+    {
+        const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
+            { LLM_KV_TOKENIZER_BOS_ID,     special_bos_id     },
+            { LLM_KV_TOKENIZER_EOS_ID,     special_eos_id     },
+            { LLM_KV_TOKENIZER_EOT_ID,     special_eot_id     },
+            { LLM_KV_TOKENIZER_EOM_ID,     special_eom_id     },
+            { LLM_KV_TOKENIZER_UNK_ID,     special_unk_id     },
+            { LLM_KV_TOKENIZER_SEP_ID,     special_sep_id     },
+            { LLM_KV_TOKENIZER_PAD_ID,     special_pad_id     },
+            { LLM_KV_TOKENIZER_MASK_ID,    special_mask_id    },
+            { LLM_KV_TOKENIZER_FIM_PRE_ID, special_fim_pre_id },
+            { LLM_KV_TOKENIZER_FIM_SUF_ID, special_fim_suf_id },
+            { LLM_KV_TOKENIZER_FIM_MID_ID, special_fim_mid_id },
+            { LLM_KV_TOKENIZER_FIM_PAD_ID, special_fim_pad_id },
+            { LLM_KV_TOKENIZER_FIM_REP_ID, special_fim_rep_id },
+            { LLM_KV_TOKENIZER_FIM_SEP_ID, special_fim_sep_id },
+
+            // deprecated
+            { LLM_KV_TOKENIZER_PREFIX_ID, special_fim_pre_id },
+            { LLM_KV_TOKENIZER_SUFFIX_ID, special_fim_suf_id },
+            { LLM_KV_TOKENIZER_MIDDLE_ID, special_fim_mid_id },
+        };
+
+        for (const auto & it : special_token_types) {
+            const std::string & key = kv(std::get<0>(it));
+            int32_t & id = std::get<1>(it);
+
+            uint32_t new_id;
+            if (!ml.get_key(std::get<0>(it), new_id, false)) {
+                continue;
+            }
+            if (new_id >= id_to_token.size()) {
+                LLAMA_LOG_WARN("%s: bad special token: '%s' = %u, using default id %d\n",
+                    __func__, key.c_str(), new_id, id);
+            } else {
+                id = new_id;
+            }
+        }
+
+        // Handle add_bos and add_eos
+        {
+            bool temp = true;
+
+            if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
+                add_bos = temp;
+            }
+            if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
+                add_eos = temp;
+            }
+        }
 
-                if (add_special && vocab.tokenizer_add_eos) {
-                    GGML_ASSERT(vocab.special_eos_id != LLAMA_TOKEN_NULL);
-                    output.push_back(vocab.special_eos_id);
+        // auto-detect special tokens by text
+        // TODO: convert scripts should provide these tokens through the KV metadata LLM_KV_TOKENIZER_...
+        //       for now, we apply this workaround to find the tokens based on their text
+
+        for (const auto & t : token_to_id) {
+            // find EOT token: "<|eot_id|>", "<|im_end|>", "<end_of_turn>", etc.
+            if (special_eot_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|eot_id|>"
+                        || t.first == "<|im_end|>"
+                        || t.first == "<|end|>"
+                        || t.first == ""
+                        || t.first == "<|endoftext|>"
+                        || t.first == ""
+                        || t.first == "<|end▁of▁sentence|>" // DeepSeek
+                   ) {
+                    special_eot_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
                 }
-            } break;
-        case LLAMA_VOCAB_TYPE_RWKV:
-            {
-                llm_tokenizer_rwkv_session session(vocab);
-                for (const auto & fragment : fragment_buffer) {
-                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                        auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
+            }
 
-#ifdef PRETOKENIZERDEBUG
-                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
-#endif
+            // find EOM token: "<|eom_id|>"
+            if (special_eom_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|eom_id|>"
+                        ) {
+                    special_eom_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
 
-                        session.tokenize(raw_text, output);
-                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
+            // find FIM_PRE token: "<|fim_prefix|>", "<fim-prefix>", "<PRE>", etc.
+            if (special_fim_pre_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_prefix|>"  // Qwen
+                        || t.first == "<fim-prefix>"
+                        || t.first == "<|fim▁begin|>" // DeepSeek
+                        || t.first == "<PRE>"
+                        ) {
+                    special_fim_pre_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                     }
                 }
-            } break;
-        case LLAMA_VOCAB_TYPE_NONE:
-            GGML_ABORT("fatal error");
+            }
+
+            // find FIM_SUF token: "<|fim_suffix|>", "<fim-suffix>", "<SUF>", etc.
+            if (special_fim_suf_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_suffix|>" // Qwen
+                        || t.first == "<fim-suffix>"
+                        || t.first == "<|fim▁hole|>" // DeepSeek
+                        || t.first == "<SUF>"
+                        ) {
+                    special_fim_suf_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_MID token: "<|fim_middle|>", "<fim-middle>", "<MID>", etc.
+            if (special_fim_mid_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_middle|>" // Qwen
+                        || t.first == "<fim-middle>"
+                        || t.first == "<|fim▁end|>"  // DeepSeek
+                        || t.first == "<MID>"
+                        ) {
+                    special_fim_mid_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_PAD token: "<|fim_pad|>", "<fim-pad>", "<PAD>", etc.
+            if (special_fim_pad_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_pad|>" // Qwen
+                        || t.first == "<fim-pad>"
+                        || t.first == "<PAD>"
+                        ) {
+                    special_fim_pad_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_REP token: "<|fim_repo|>", "<fim-repo>", "<REPO>", etc.
+            if (special_fim_rep_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_repo|>"  // Qwen
+                        || t.first == "<|repo_name|>"
+                        || t.first == "<fim-repo>"
+                        || t.first == "<REPO>"
+                        ) {
+                    special_fim_rep_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_SEP token: "<|file_sep|>"
+            if (special_fim_sep_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|file_sep|>" // Qwen
+                        ) {
+                    special_fim_sep_id = t.second;
+                    if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+        }
+
+        // maintain a list of tokens that cause end-of-generation
+        // this is currently determined based on the token text, which is obviously not ideal
+        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
+        special_eog_ids.clear();
+
+        if (special_fim_pad_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_pad_id) == 0) {
+            special_eog_ids.insert(special_fim_pad_id);
+        }
+
+        if (special_fim_rep_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_rep_id) == 0) {
+            special_eog_ids.insert(special_fim_rep_id);
+        }
+
+        if (special_fim_sep_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_sep_id) == 0) {
+            special_eog_ids.insert(special_fim_sep_id);
+        }
+
+        for (const auto & t : token_to_id) {
+            if (false
+                    || t.first == "<|eot_id|>"
+                    || t.first == "<|im_end|>"
+                    || t.first == "<|end|>"
+                    || t.first == ""
+                    || t.first == "<|endoftext|>"
+                    || t.first == "<|eom_id|>"
+                    || t.first == ""
+               ) {
+                special_eog_ids.insert(t.second);
+                if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.second, t.first.c_str());
+                    id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                }
+            } else {
+                // token is control, but not marked as EOG -> print a debug log
+                if (id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && special_eog_ids.count(t.second) == 0) {
+                    LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
+                            __func__, t.second, t.first.c_str());
+                }
+            }
+        }
+
+        // sanity checks
+        if (special_eos_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eos_id) == 0) {
+            special_eog_ids.insert(special_eos_id);
+            LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (special_eot_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eot_id) == 0) {
+            special_eog_ids.insert(special_eot_id);
+            LLAMA_LOG_WARN("%s: special_eot_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
+
+        if (special_eom_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eom_id) == 0) {
+            special_eog_ids.insert(special_eom_id);
+            LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
+        }
     }
 
-    return output;
-}
+    // build special tokens cache
+    {
+        for (llama_token id = 0; id < (llama_token) n_tokens; ++id) {
+            if (id_to_token[id].attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN)) {
+                cache_special_tokens.push_back(id);
+            }
+        }
 
-llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch) {
-    GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
-    static const char * hex = "0123456789ABCDEF";
-    switch (llama_vocab_get_type(vocab)) {
-        case LLAMA_VOCAB_TYPE_SPM:
-        case LLAMA_VOCAB_TYPE_UGM: {
-            const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
-            auto token = vocab.token_to_id.find(buf);
-            if (token != vocab.token_to_id.end()) {
-                return (*token).second;
+        std::sort(cache_special_tokens.begin(), cache_special_tokens.end(),
+            [&] (const llama_token a, const llama_token b) {
+                return id_to_token[a].text.size() > id_to_token[b].text.size();
             }
-            // Try to fall back to just the byte as a string
-            const char buf2[2] = { (char)ch, 0 };
-            return vocab.token_to_id.at(buf2);
+        );
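+        // longest-first ordering lets tokenizer_st_partition() match longer special
+        // tokens before shorter ones that may be substrings of them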
+
+        LLAMA_LOG_INFO("%s: special tokens cache size = %u\n", __func__, (uint32_t) cache_special_tokens.size());
+    }
+
+    // build token to piece cache
+    {
+        size_t size_cache = 0;
+
+        std::vector<std::string> cache(n_tokens);
+
+        for (uint32_t id = 0; id < n_tokens; ++id) {
+            cache[id] = token_to_piece_for_cache(id, true);
+
+            size_cache += cache[id].size();
         }
-        case LLAMA_VOCAB_TYPE_WPM:
-        case LLAMA_VOCAB_TYPE_BPE: {
-            return vocab.token_to_id.at(unicode_byte_to_utf8(ch));
+
+        std::swap(cache_token_to_piece, cache);
+
+        LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
+    }
+
+    // Handle per token attributes
+    //NOTE: Each model customizes per token attributes.
+    //NOTE: Per token attributes are missing from the GGUF file.
+    //TODO: Extract attributes from GGUF file.
+    {
+        auto _contains_any = [] (const std::string & str, const std::vector<std::string> & substrs) -> bool {
+            for (const auto & substr : substrs) {
+                if (str.find(substr) < std::string::npos) {
+                    return true;
+                }
+            }
+            return false;
+        };
+
+        auto _set_tokenid_attr = [&] (const llama_token id, llama_token_attr attr, bool value) {
+            uint32_t current = id_to_token.at(id).attr;
+            current = value ? (current | attr) : (current & ~attr);
+            id_to_token[id].attr = (llama_token_attr) current;
+        };
+
+        auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
+            _set_tokenid_attr(token_to_id.at(token), attr, value);
+        };
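+        // the per-model overrides below key off the (lowercased) general.name and the
+        // tokenizer.ggml.pre metadata values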
+
+        std::string model_name;
+        std::string tokenizer_pre;
+
+        ml.get_key(LLM_KV_GENERAL_NAME,  model_name,    false);
+        ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
+
+        // model name to lowercase
+        std::transform(model_name.begin(), model_name.end(), model_name.begin(),
+            [] (const std::string::value_type x) {
+                return std::tolower(x);
+            }
+        );
+
+        // set attributes by model/tokenizer name
+        if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
+            _set_token_attr("", LLAMA_TOKEN_ATTR_LSTRIP, true);
+        } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
+            for (auto id : cache_special_tokens) {
+                _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (const auto * token : {""}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (const auto * token : {"", "", "<|endoftext|>"}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
+            }
         }
-        default:
-            GGML_ABORT("fatal error");
     }
 }
 
-const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[token].text.c_str();
+enum llama_vocab_type llama_vocab::impl::get_type() const {
+    return type;
 }
 
-float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[token].score;
+std::string llama_vocab::impl::type_name() const {
+    switch (type) {
+        case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
+        case LLAMA_VOCAB_TYPE_SPM:  return "SPM";
+        case LLAMA_VOCAB_TYPE_BPE:  return "BPE";
+        case LLAMA_VOCAB_TYPE_WPM:  return "WPM";
+        case LLAMA_VOCAB_TYPE_UGM:  return "UGM";
+        case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
+        default:                    return "unknown";
+    }
 }
 
-llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token) {
-    GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[token].attr;
+bool llama_vocab::impl::is_normal(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
 }
 
-bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token) {
-    return token != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(token) > 0;
+bool llama_vocab::impl::is_unknown(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
 }
 
-bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token) {
-    return llama_is_control_token(vocab, token);
+bool llama_vocab::impl::is_control(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
 }
 
-llama_token llama_token_bos_impl(const struct llama_vocab & vocab) {
-    return vocab.type != LLAMA_VOCAB_TYPE_WPM ? vocab.special_bos_id : vocab.special_cls_id;
+bool llama_vocab::impl::is_byte(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
 }
 
-llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eos_id;
+bool llama_vocab::impl::is_user_defined(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
 }
 
-llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eot_id;
+bool llama_vocab::impl::is_unused(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNUSED;
 }
 
-llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eom_id;
+bool llama_vocab::impl::is_eog(llama_token id) const {
+    return id != LLAMA_TOKEN_NULL && special_eog_ids.count(id) > 0;
 }
 
-llama_token llama_token_cls_impl(const struct llama_vocab & vocab) {
-    return vocab.special_cls_id;
+uint8_t llama_vocab::impl::token_to_byte(llama_token id) const {
+    GGML_ASSERT(get_type() != LLAMA_VOCAB_TYPE_NONE);
+    GGML_ASSERT(is_byte(id));
+    const auto & token_data = id_to_token.at(id);
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+        case LLAMA_VOCAB_TYPE_UGM: {
+            auto buf = token_data.text.substr(3, 2);
+            return strtol(buf.c_str(), NULL, 16);
+        }
+        case LLAMA_VOCAB_TYPE_BPE: {
+            GGML_ABORT("fatal error");
+        }
+        case LLAMA_VOCAB_TYPE_WPM: {
+            GGML_ABORT("fatal error");
+        }
+        default:
+            GGML_ABORT("fatal error");
+    }
 }
 
-llama_token llama_token_sep_impl(const struct llama_vocab & vocab) {
-    return vocab.special_sep_id;
+llama_token_attr llama_vocab::impl::token_get_attr(llama_token id) const {
+    GGML_ASSERT(type != LLAMA_VOCAB_TYPE_NONE);
+    return id_to_token.at(id).attr;
 }
 
-llama_token llama_token_nl_impl(const struct llama_vocab & vocab) {
-    return vocab.linefeed_id;
-}
+void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) {
+    LLAMA_LOG_DEBUG("%s: initializing tokenizer for type %d\n", __func__, type);
 
-llama_token llama_token_pad_impl(const struct llama_vocab & vocab) {
-    return vocab.special_pad_id;
+    switch (type) {
+        case LLAMA_VOCAB_TYPE_SPM:
+            tokenizer = std::make_unique<llm_tokenizer_spm>(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_BPE:
+            tokenizer = std::make_unique<llm_tokenizer_bpe>(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_WPM:
+            tokenizer = std::make_unique<llm_tokenizer_wpm>(vocab);
+            break;
+        case LLAMA_VOCAB_TYPE_UGM:
+            tokenizer = std::make_unique<llm_tokenizer_ugm>(vocab, precompiled_charsmap);
+            break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            tokenizer = std::make_unique<llm_tokenizer_rwkv>(vocab);
+            break;
+        default:
+            GGML_ABORT("unsupported vocab type");
+    }
 }
 
-bool llama_add_bos_token_impl(const struct llama_vocab & vocab) {
-    return vocab.tokenizer_add_bos;
-}
+//
+// (de-) tokenize
+//
 
-bool llama_add_eos_token_impl(const struct llama_vocab & vocab) {
-    return vocab.tokenizer_add_eos;
-}
+// #define PRETOKENIZERDEBUG
 
-llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_pre_id;
-}
+void llama_vocab::impl::tokenizer_st_partition(std::forward_list<fragment_buffer_variant> & buffer, bool parse_special) const {
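+    // scan every RAW_TEXT fragment for each special token's text; a match splits the
+    // fragment in-place into [left text][special token id][right text]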
+    // for each special token
+    for (const llama_token special_id : cache_special_tokens) {
+        const auto & data = vocab.get_token_data(special_id);
+        const auto & text = data.text;
 
-llama_token llama_token_middle_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_mid_id;
-}
+        if (!parse_special && (data.attr & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_UNKNOWN))) {
+            // Ignore control and unknown tokens when parse_special == false
+            continue;
+            // User-defined tokens are still pre-tokenized before everything else
+            // ref: https://github.com/huggingface/tokenizers/blob/fdd26ba9a3f0c133427aab0423888cbde91362d7/tokenizers/src/tokenizer/mod.rs#L726
+            // This is mostly relevant for neox-style tokenizers (mpt, olmo, stablelm, etc.)
+        }
 
-llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_suf_id;
-}
+        // for each text fragment
+        std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
+        while (it != buffer.end()) {
+            auto & fragment = (*it);
 
-llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_pre_id;
-}
+            // if a fragment is text ( not yet processed )
+            if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                const auto & raw_text = fragment.raw_text;
 
-llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_suf_id;
-}
+                auto raw_text_base_offset = fragment.offset;
+                auto raw_text_base_length = fragment.length;
 
-llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_mid_id;
-}
+                // loop over the text
+                while (true) {
+                    // find the first occurrence of a given special token in this fragment
+                    //  passing offset argument only limit the "search area" but match coordinates
+                    //  are still relative to the source full raw_text
+                    auto match = raw_text.find(text, raw_text_base_offset);
 
-llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_pad_id;
-}
+                    // no occurrences found, stop processing this fragment for a given special token
+                    if (match == std::string::npos) break;
 
-llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_rep_id;
-}
+                    // check if match is within bounds of offset <-> length
+                    if (match + text.length() > raw_text_base_offset + raw_text_base_length) break;
 
-llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab) {
-    return vocab.special_fim_sep_id;
-}
+#ifdef PRETOKENIZERDEBUG
+                    LLAMA_LOG_WARN("FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
+#endif
+                    auto source = std::distance(buffer.begin(), it);
 
-int32_t llama_tokenize_impl(
-        const struct llama_vocab & vocab,
-                      const char * text,
-                         int32_t   text_len,
-                     llama_token * tokens,
-                         int32_t   n_tokens_max,
-                            bool   add_special,
-                            bool   parse_special) {
-    auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special);
-    if (n_tokens_max < (int) res.size()) {
-        // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
-        return -((int) res.size());
+                    // if match is further than base offset
+                    //  then we have some text to the left of it
+                    if (match > raw_text_base_offset) {
+                        // left
+                        const int64_t left_reminder_offset = raw_text_base_offset + 0;
+                        int64_t left_reminder_length = match - raw_text_base_offset;
+
+                        if (data.attr & LLAMA_TOKEN_ATTR_LSTRIP) {
+                            while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) {
+                                left_reminder_length--;
+                            }
+                        }
+
+                        if (left_reminder_length > 0) {
+                            buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
+                            it++;
+                        }
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
+#endif
+                    }
+
+                    // special token
+                    buffer.emplace_after(it, special_id);
+                    it++;
+
+                    // right
+                    if (match + text.length() < raw_text_base_offset + raw_text_base_length) {
+                        int64_t right_reminder_offset = match + text.length();
+                        int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + text.length());
+
+                        if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) {
+                            while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
+                                right_reminder_offset++;
+                                right_reminder_length--;
+                            }
+                        }
+
+                        if (right_reminder_length > 0) {
+                            buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
+                            it++;
+                        }
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
+#endif
+
+                        if (source == 0) {
+                            buffer.erase_after(buffer.before_begin());
+                        } else {
+                            buffer.erase_after(std::next(buffer.begin(), (source - 1)));
+                        }
+
+                        // repeat for the right side
+                        raw_text_base_offset = right_reminder_offset;
+                        raw_text_base_length = right_reminder_length;
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str());
+#endif
+                    } else {
+                        if (source == 0) {
+                            buffer.erase_after(buffer.before_begin());
+                        } else {
+                            buffer.erase_after(std::next(buffer.begin(), (source - 1)));
+                        }
+                        break;
+                    }
+                }
+            }
+            it++;
+        }
     }
+}
 
-    for (size_t i = 0; i < res.size(); i++) {
-        tokens[i] = res[i];
+// NOTE: avoid ever using this except for building the token_to_piece caches
+std::string llama_vocab::impl::token_to_piece_for_cache(llama_token token, bool special) const {
+    std::string piece;
+    piece.resize(piece.capacity());  // using string internal cache
+    const int n_chars = vocab.token_to_piece(token, &piece[0], piece.size(), 0, special);
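+    // a negative result is the negated size actually required: grow the buffer and convert again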
+    if (n_chars < 0) {
+        piece.resize(-n_chars);
+        int check = vocab.token_to_piece(token, &piece[0], piece.size(), 0, special);
+        GGML_ASSERT(check == -n_chars);
+    }
+    else {
+        piece.resize(n_chars);
     }
 
-    return res.size();
+    return piece;
+}
+
+static void llama_escape_whitespace(std::string & text) {
+    replace_all(text, " ", "\xe2\x96\x81");
+}
+
+static void llama_unescape_whitespace(std::string & word) {
+    replace_all(word, "\xe2\x96\x81", " ");
 }
 
 static std::string llama_decode_text(const std::string & text) {
@@ -1766,11 +2304,185 @@ static std::string llama_decode_text(const std::string & text) {
     return decoded_text;
 }
 
-// does not write null-terminator to buf
-int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) {
+std::vector<llama_token> llama_vocab::impl::tokenize(
+        const std::string & raw_text,
+        bool add_special,
+        bool parse_special) const {
+    GGML_ASSERT(tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+
+    std::vector<llama_token> output;
+    std::forward_list<fragment_buffer_variant> fragment_buffer;
+
+    if (!raw_text.empty()) {
+        fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
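+        // partition the raw text into plain-text fragments and special-token fragments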
+        tokenizer_st_partition(fragment_buffer, parse_special);
+    }
+
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+            {
+                // OG tokenizer behavior:
+                //
+                // tokenizer.encode('', add_special_tokens=True)  returns [1]
+                // tokenizer.encode('', add_special_tokens=False) returns []
+
+                bool is_prev_special = true;  // prefix with space if first token
+
+                if (add_special && add_bos) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                    is_prev_special = true;
+                }
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text;
+
+                        // prefix with space if previous is special
+                        if (add_space_prefix && is_prev_special) {
+                            text = ' ';
+                        }
+
+                        text += fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        llama_escape_whitespace(text);
+                        llm_tokenizer_spm_session session(vocab);
+                        session.tokenize(text, output);
+                        is_prev_special = false;
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                        is_prev_special = true;
+                    }
+                }
+
+                if (add_special && add_bos && output.size() >= 2 && output[1] == special_bos_id) {
+                    LLAMA_LOG_WARN(
+                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                        "Are you sure this is what you want?\n", __FUNCTION__);
+                }
+
+                if (add_special && add_eos) {
+                    GGML_ASSERT(special_eos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_eos_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_BPE:
+            {
+                // the session calls methods that do not exist on the base llm_tokenizer,
+                // so cast the stored tokenizer to the BPE tokenizer type here
+                llm_tokenizer_bpe_session session(vocab, *static_cast<const llm_tokenizer_bpe *>(tokenizer.get()));
+                if (add_special) {
+                    session.append_bos(output);
+                }
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        session.append(fragment.token, output);
+                    }
+                }
+
+                if (add_special) {
+                    session.append_eos(output);
+                    session.check_double_bos_eos(output);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_WPM:
+            {
+                if (add_special) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                }
+
+                llm_tokenizer_wpm_session session(vocab);
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+
+                if (add_special) {
+                    GGML_ASSERT(special_sep_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_sep_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_UGM:
+            {
+                if (add_special && add_bos) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                }
+                llm_tokenizer_ugm_session session(vocab, *static_cast<const llm_tokenizer_ugm *>(tokenizer.get()));
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+
+                if (add_special && add_bos && output.size() >= 2 && output[1] == special_bos_id) {
+                    LLAMA_LOG_WARN(
+                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                        "Are you sure this is what you want?\n", __FUNCTION__);
+                }
+
+                if (add_special && add_eos) {
+                    GGML_ASSERT(special_eos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_eos_id);
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_RWKV:
+            {
+                llm_tokenizer_rwkv_session session(vocab, *static_cast<const llm_tokenizer_rwkv *>(tokenizer.get()));
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+            } break;
+        case LLAMA_VOCAB_TYPE_NONE:
+            GGML_ABORT("fatal error");
+    }
+
+    return output;
+}
+
+int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) const {
     // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
     static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
-    const llama_token_attr attr = llama_token_get_attr_impl(vocab, token);
+    const llama_token_attr attr = token_get_attr(token);
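+    // without the 'special' flag, control and unknown tokens render as empty pieces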
     if (!special && (attr & attr_special)) {
         return 0;
     }
@@ -1791,7 +2503,7 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
 
     // if we have a cache - use it
     {
-        const auto & cache = vocab.cache_token_to_piece;
+        const auto & cache = cache_token_to_piece;
 
         if (!cache.empty()) {
             const auto & result = cache.at(token);
@@ -1799,9 +2511,9 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
         }
     }
 
-    if (0 <= token && token < (int32_t) vocab.id_to_token.size()) {
-        const std::string & token_text = vocab.id_to_token[token].text;
-        switch (llama_vocab_get_type(vocab)) {
+    if (0 <= token && token < (int32_t) id_to_token.size()) {
+        const std::string & token_text = id_to_token[token].text;
+        switch (get_type()) {
             case LLAMA_VOCAB_TYPE_WPM:
             case LLAMA_VOCAB_TYPE_SPM:
             case LLAMA_VOCAB_TYPE_UGM: {
@@ -1816,7 +2528,7 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
                     return _try_copy(result.data(), result.size());
                 }
                 if (attr & LLAMA_TOKEN_ATTR_BYTE) {
-                    char byte = (char) llama_token_to_byte(vocab, token);
+                    char byte = (char) token_to_byte(token);
                     return _try_copy((char*) &byte, 1);
                 }
                 break;
@@ -1852,43 +2564,46 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
     return 0;
 }
 
-int32_t llama_detokenize_impl(
-        const struct llama_vocab & vocab,
+const std::string & llama_vocab::impl::token_to_piece(llama_token token) const {
+    return cache_token_to_piece.at(token);
+}
+
+int32_t llama_vocab::impl::detokenize(
                const llama_token * tokens,
                          int32_t   n_tokens,
                             char * text,
                          int32_t   text_len_max,
                             bool   remove_special,
-                            bool   unparse_special) {
-    if (vocab.type == LLAMA_VOCAB_TYPE_NONE) {
+                            bool   unparse_special) const {
+    if (type == LLAMA_VOCAB_TYPE_NONE) {
         return 0;
     }
 
-    GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
+    GGML_ASSERT(tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
 
     int32_t avail = text_len_max;
     int32_t total = 0;
 
     // remove the leading space
-    bool remove_space = vocab.tokenizer_add_space_prefix;
+    bool remove_space = add_space_prefix;
 
-    if (remove_special && vocab.tokenizer_add_bos) {
-        if (n_tokens > 0 && tokens[0] == vocab.special_bos_id) {
+    if (remove_special && add_bos) {
+        if (n_tokens > 0 && tokens[0] == special_bos_id) {
             remove_space = false;
             n_tokens--;
             tokens++;
         }
     }
 
-    if (remove_special && vocab.tokenizer_add_eos) {
-        if (n_tokens > 0 && tokens[n_tokens - 1] == vocab.special_eos_id) {
+    if (remove_special && add_eos) {
+        if (n_tokens > 0 && tokens[n_tokens - 1] == special_eos_id) {
             n_tokens--;
         }
     }
 
     for (int32_t i = 0; i < n_tokens; ++i) {
         GGML_ASSERT(avail >= 0);
-        int32_t n_chars = llama_token_to_piece_impl(vocab, tokens[i], text, avail, remove_space, unparse_special);
+        int32_t n_chars = token_to_piece(tokens[i], text, avail, remove_space, unparse_special);
         remove_space = false;
         if (n_chars < 0) {
             avail = 0;
@@ -1904,7 +2619,7 @@ int32_t llama_detokenize_impl(
         return -total;
     }
 
-    if (vocab.tokenizer_clean_spaces) {
+    if (clean_spaces) {
         text -= total;  // restart text
 
         // first pass: characters ?!.,  //TODO: where do these characters come from?
@@ -1965,13 +2680,321 @@ int32_t llama_detokenize_impl(
     return total <= text_len_max ? total : -total;
 }
 
-std::string llama_detokenize(const struct llama_vocab & vocab, const std::vector<llama_token> & tokens, bool special) {
+void llama_vocab::impl::print_info() const {
+    LLAMA_LOG_INFO("%s: vocab type       = %s\n",     __func__, type_name().c_str());
+    LLAMA_LOG_INFO("%s: n_vocab          = %u\n",     __func__, vocab.n_tokens());
+    LLAMA_LOG_INFO("%s: n_merges         = %u\n",     __func__, (uint32_t) bpe_ranks.size());
+
+    // special tokens
+    if (special_bos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, special_bos_id,     id_to_token[special_bos_id].text.c_str() );  }
+    if (special_eos_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, special_eos_id,     id_to_token[special_eos_id].text.c_str() );  }
+    if (special_eot_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, special_eot_id,     id_to_token[special_eot_id].text.c_str() );  }
+    if (special_eom_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, special_eom_id,     id_to_token[special_eom_id].text.c_str() );  }
+    if (special_unk_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, special_unk_id,     id_to_token[special_unk_id].text.c_str() );  }
+    if (special_sep_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, special_sep_id,     id_to_token[special_sep_id].text.c_str() );  }
+    if (special_pad_id  != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, special_pad_id,     id_to_token[special_pad_id].text.c_str() );  }
+    if (special_mask_id != LLAMA_TOKEN_NULL)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, special_mask_id,    id_to_token[special_mask_id].text.c_str() ); }
+
+    if (linefeed_id != LLAMA_TOKEN_NULL)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, linefeed_id,        id_to_token[linefeed_id].text.c_str() ); }
+
+    if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, special_fim_pre_id, id_to_token[special_fim_pre_id].text.c_str() ); }
+    if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, special_fim_suf_id, id_to_token[special_fim_suf_id].text.c_str() ); }
+    if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, special_fim_mid_id, id_to_token[special_fim_mid_id].text.c_str() ); }
+    if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, special_fim_pad_id, id_to_token[special_fim_pad_id].text.c_str() ); }
+    if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, special_fim_rep_id, id_to_token[special_fim_rep_id].text.c_str() ); }
+    if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, special_fim_sep_id, id_to_token[special_fim_sep_id].text.c_str() ); }
+
+    for (const auto & id : special_eog_ids) {
+        LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, id_to_token[id].text.c_str() );
+    }
+
+    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len);
+}
+
+llama_vocab::llama_vocab() : pimpl(new impl(*this)) {
+}
+
+llama_vocab::~llama_vocab() {
+}
+
+void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
+    pimpl->load(ml, kv);
+}
+
+enum llama_vocab_type llama_vocab::get_type() const {
+    return pimpl->type;
+}
+
+enum llama_vocab_pre_type llama_vocab::get_pre_type() const {
+    return pimpl->pre_type;
+}
+
+uint32_t llama_vocab::n_tokens() const {
+    return (uint32_t) pimpl->id_to_token.size();
+}
+
+uint32_t llama_vocab::n_token_types() const {
+    return (uint32_t) pimpl->n_token_types;
+}
+
+std::string llama_vocab::type_name() const {
+    return pimpl->type_name();
+}
+
+bool llama_vocab::is_normal(llama_token id) const {
+    return pimpl->is_normal(id);
+}
+
+bool llama_vocab::is_unknown(llama_token id) const {
+    return pimpl->is_unknown(id);
+}
+
+bool llama_vocab::is_control(llama_token id) const {
+    return pimpl->is_control(id);
+}
+
+bool llama_vocab::is_byte(llama_token id) const {
+    return pimpl->is_byte(id);
+}
+
+bool llama_vocab::is_user_defined(llama_token id) const {
+    return pimpl->is_user_defined(id);
+}
+
+bool llama_vocab::is_unused(llama_token id) const {
+    return pimpl->is_unused(id);
+}
+
+bool llama_vocab::is_eog(llama_token id) const {
+    return pimpl->is_eog(id);
+}
+
+uint8_t llama_vocab::token_to_byte(llama_token id) const {
+    return pimpl->token_to_byte(id);
+}
+
+llama_token llama_vocab::byte_to_token(uint8_t ch) const {
+    GGML_ASSERT(get_type() != LLAMA_VOCAB_TYPE_NONE);
+    static const char * hex = "0123456789ABCDEF";
+    switch (get_type()) {
+        case LLAMA_VOCAB_TYPE_SPM:
+        case LLAMA_VOCAB_TYPE_UGM: {
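+            // SPM and UGM vocabs store raw bytes as "<0xXX>" tokens; build that text and look it up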
+            const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 };
+            auto token = pimpl->token_to_id.find(buf);
+            if (token != pimpl->token_to_id.end()) {
+                return (*token).second;
+            }
+            // Try to fall back to just the byte as a string
+            const char buf2[2] = { (char)ch, 0 };
+            return pimpl->token_to_id.at(buf2);
+        }
+        case LLAMA_VOCAB_TYPE_WPM:
+        case LLAMA_VOCAB_TYPE_BPE: {
+            return pimpl->token_to_id.at(unicode_byte_to_utf8(ch));
+        }
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
+
+llama_token llama_vocab::text_to_token(const std::string & text) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    auto it = pimpl->token_to_id.find(text);
+    if (it != pimpl->token_to_id.end()) {
+        return (*it).second;
+    }
+    return LLAMA_TOKEN_NULL;
+}
+
+const llama_vocab::token_data & llama_vocab::get_token_data(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id);
+}
+
+const char * llama_vocab::token_get_text(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id).text.c_str();
+}
+
+float llama_vocab::token_get_score(llama_token id) const {
+    GGML_ASSERT(pimpl->type != LLAMA_VOCAB_TYPE_NONE);
+    return pimpl->id_to_token.at(id).score;
+}
+
+llama_token_attr llama_vocab::token_get_attr(llama_token id) const {
+    return pimpl->token_get_attr(id);
+}
+
+llama_token llama_vocab::token_bos() const {
+    return pimpl->special_bos_id;
+}
+
+llama_token llama_vocab::token_eos() const {
+    return pimpl->special_eos_id;
+}
+
+llama_token llama_vocab::token_eot() const {
+    return pimpl->special_eot_id;
+}
+
+llama_token llama_vocab::token_eom() const {
+    return pimpl->special_eom_id;
+}
+
+llama_token llama_vocab::token_unk() const {
+    return pimpl->special_unk_id;
+}
+
+llama_token llama_vocab::token_sep() const {
+    return pimpl->special_sep_id;
+}
+
+llama_token llama_vocab::token_nl() const {
+    return pimpl->linefeed_id;
+}
+
+llama_token llama_vocab::token_pad() const {
+    return pimpl->special_pad_id;
+}
+
+llama_token llama_vocab::token_prefix() const {
+    return pimpl->special_fim_pre_id;
+}
+
+llama_token llama_vocab::token_middle() const {
+    return pimpl->special_fim_mid_id;
+}
+
+llama_token llama_vocab::token_suffix() const {
+    return pimpl->special_fim_suf_id;
+}
+
+llama_token llama_vocab::token_fim_pre() const {
+    return pimpl->special_fim_pre_id;
+}
+
+llama_token llama_vocab::token_fim_suf() const {
+    return pimpl->special_fim_suf_id;
+}
+
+llama_token llama_vocab::token_fim_mid() const {
+    return pimpl->special_fim_mid_id;
+}
+
+llama_token llama_vocab::token_fim_pad() const {
+    return pimpl->special_fim_pad_id;
+}
+
+llama_token llama_vocab::token_fim_rep() const {
+    return pimpl->special_fim_rep_id;
+}
+
+llama_token llama_vocab::token_fim_sep() const {
+    return pimpl->special_fim_sep_id;
+}
+
+bool llama_vocab::get_add_space_prefix() const {
+    return pimpl->add_space_prefix;
+}
+
+bool llama_vocab::get_add_bos() const {
+    return pimpl->add_bos;
+}
+
+bool llama_vocab::get_add_eos() const {
+    return pimpl->add_eos;
+}
+
+bool llama_vocab::get_ignore_merges() const {
+    return pimpl->ignore_merges;
+}
+
+bool llama_vocab::get_clean_spaces() const {
+    return pimpl->clean_spaces;
+}
+
+bool llama_vocab::get_remove_extra_whitespaces() const {
+    return pimpl->remove_extra_whitespaces;
+}
+
+bool llama_vocab::get_escape_whitespaces() const {
+    return pimpl->escape_whitespaces;
+}
+
+bool llama_vocab::get_treat_whitespace_as_suffix() const {
+    return pimpl->treat_whitespace_as_suffix;
+}
+
+int llama_vocab::max_token_len() const {
+    return pimpl->max_token_len;
+}
+
+int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
+    GGML_ASSERT(token_left.find(' ')   == std::string::npos);
+    GGML_ASSERT(token_left.find('\n')  == std::string::npos);
+    GGML_ASSERT(token_right.find(' ')  == std::string::npos);
+    GGML_ASSERT(token_right.find('\n') == std::string::npos);
+
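+    // bpe_ranks is keyed on (left, right) merge pairs; unknown pairs rank as -1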
+    auto it = pimpl->bpe_ranks.find(std::make_pair(token_left, token_right));
+    if (it == pimpl->bpe_ranks.end()) {
+        return -1;
+    }
+
+    return it->second;
+}
+
+int32_t llama_vocab::tokenize(
+                  const char * text,
+                     int32_t   text_len,
+                 llama_token * tokens,
+                     int32_t   n_tokens_max,
+                        bool   add_special,
+                        bool   parse_special) const {
+    auto res = tokenize(std::string(text, text_len), add_special, parse_special);
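+    // if the caller's buffer is too small, report the required count as a negative value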
+    if (n_tokens_max < (int) res.size()) {
+        // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
+        return -((int) res.size());
+    }
+
+    for (size_t i = 0; i < res.size(); i++) {
+        tokens[i] = res[i];
+    }
+
+    return res.size();
+}
+
+std::vector<llama_token> llama_vocab::tokenize(
+        const std::string & raw_text,
+        bool add_special,
+        bool parse_special) const {
+    return pimpl->tokenize(raw_text, add_special, parse_special);
+}
+
+const std::string & llama_vocab::token_to_piece(llama_token token) const {
+    return pimpl->token_to_piece(token);
+}
+
+int32_t llama_vocab::token_to_piece(llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) const {
+    return pimpl->token_to_piece(token, buf, length, lstrip, special);
+}
+
+int32_t llama_vocab::detokenize(
+               const llama_token * tokens,
+                         int32_t   n_tokens,
+                            char * text,
+                         int32_t   text_len_max,
+                            bool   remove_special,
+                            bool   unparse_special) const {
+    return pimpl->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
+}
+
+std::string llama_vocab::detokenize(const std::vector<llama_token> & tokens, bool special) const {
     std::string text;
     text.resize(std::max(text.capacity(), tokens.size()));
-    int32_t n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+    int32_t n_chars = detokenize(tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
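+    // a negative result means the buffer was too small; -n_chars is the size required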
     if (n_chars < 0) {
         text.resize(-n_chars);
-        n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        n_chars = detokenize(tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
         GGML_ASSERT(n_chars <= (int32_t)text.size());  // whitespace trimming is performed after per-token detokenization
     }
 
@@ -1980,3 +3003,243 @@ std::string llama_detokenize(const struct llama_vocab & vocab, const std::vector
     // NOTE: the original tokenizer decodes bytes after collecting the pieces.
     return text;
 }
+
+void llama_vocab::print_info() const {
+    pimpl->print_info();
+}
+
+//
+// interface implementation
+//
+
+int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab) {
+    return vocab->n_tokens();
+}
+
+// deprecated
+int32_t llama_n_vocab(const struct llama_vocab * vocab) {
+    return llama_vocab_n_tokens(vocab);
+}
+
+enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab) {
+    return vocab->get_type();
+}
+
+const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_text(token);
+}
+
+float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_score(token);
+}
+
+enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->token_get_attr(token);
+}
+
+bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->is_eog(token);
+}
+
+bool llama_vocab_is_control(const struct llama_vocab * vocab, llama_token token) {
+    return vocab->is_control(token);
+}
+
+llama_token llama_vocab_bos(const struct llama_vocab * vocab) {
+    return vocab->token_bos();
+}
+
+llama_token llama_vocab_eos(const struct llama_vocab * vocab) {
+    return vocab->token_eos();
+}
+
+llama_token llama_vocab_eot(const struct llama_vocab * vocab) {
+    return vocab->token_eot();
+}
+
+// deprecated
+llama_token llama_vocab_cls(const struct llama_vocab * vocab) {
+    return vocab->token_bos();
+}
+
+llama_token llama_vocab_sep(const struct llama_vocab * vocab) {
+    return vocab->token_sep();
+}
+
+llama_token llama_vocab_nl (const struct llama_vocab * vocab) {
+    return vocab->token_nl();
+}
+
+llama_token llama_vocab_pad(const struct llama_vocab * vocab) {
+    return vocab->token_pad();
+}
+
+bool llama_vocab_get_add_bos(const struct llama_vocab * vocab) {
+    return vocab->get_add_bos();
+}
+
+bool llama_vocab_get_add_eos(const struct llama_vocab * vocab) {
+    return vocab->get_add_eos();
+}
+
+llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab) {
+    return vocab->token_fim_pre();
+}
+
+llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab) {
+    return vocab->token_fim_suf();
+}
+
+llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab) {
+    return vocab->token_fim_mid();
+}
+
+llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab) {
+    return vocab->token_fim_pad();
+}
+
+llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab) {
+    return vocab->token_fim_rep();
+}
+
+llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab) {
+    return vocab->token_fim_sep();
+}
+
+// deprecated
+const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_text(vocab, token);
+}
+
+// deprecated
+float llama_token_get_score(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_score(vocab, token);
+}
+
+// deprecated
+enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_get_attr(vocab, token);
+}
+
+// deprecated
+bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_is_eog(vocab, token);
+}
+
+// deprecated
+bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token) {
+    return llama_vocab_is_control(vocab, token);
+}
+
+// deprecated
+llama_token llama_token_bos(const struct llama_vocab * vocab) {
+    return llama_vocab_bos(vocab);
+}
+
+// deprecated
+llama_token llama_token_eos(const struct llama_vocab * vocab) {
+    return llama_vocab_eos(vocab);
+}
+
+// deprecated
+llama_token llama_token_eot(const struct llama_vocab * vocab) {
+    return llama_vocab_eot(vocab);
+}
+
+// deprecated
+llama_token llama_token_cls(const struct llama_vocab * vocab) {
+    //return llama_vocab_cls(vocab);
+    return llama_vocab_bos(vocab); // avoid deprecation warning
+}
+
+// deprecated
+llama_token llama_token_sep(const struct llama_vocab * vocab) {
+    return llama_vocab_sep(vocab);
+}
+
+// deprecated
+llama_token llama_token_nl (const struct llama_vocab * vocab) {
+    return llama_vocab_nl(vocab);
+}
+
+// deprecated
+llama_token llama_token_pad(const struct llama_vocab * vocab) {
+    return llama_vocab_pad(vocab);
+}
+
+// deprecated
+bool llama_add_bos_token(const struct llama_vocab * vocab) {
+    return llama_vocab_get_add_bos(vocab);
+}
+
+// deprecated
+bool llama_add_eos_token(const struct llama_vocab * vocab) {
+    return llama_vocab_get_add_eos(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_pre(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_pre(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_suf(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_suf(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_mid(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_mid(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_pad(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_pad(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_rep(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_rep(vocab);
+}
+
+// deprecated
+llama_token llama_token_fim_sep(const struct llama_vocab * vocab) {
+    return llama_vocab_fim_sep(vocab);
+}
+
+//
+// tokenization
+//
+
+int32_t llama_tokenize(
+    const struct llama_vocab * vocab,
+                  const char * text,
+                     int32_t   text_len,
+                 llama_token * tokens,
+                     int32_t   n_tokens_max,
+                        bool   add_special,
+                        bool   parse_special) {
+    return vocab->tokenize(text, text_len, tokens, n_tokens_max, add_special, parse_special);
+}
+
+int32_t llama_token_to_piece(
+    const struct llama_vocab * vocab,
+                 llama_token   token,
+                        char * buf,
+                     int32_t   length,
+                     int32_t   lstrip,
+                        bool   special) {
+    return vocab->token_to_piece(token, buf, length, lstrip, special);
+}
+
+int32_t llama_detokenize(
+    const struct llama_vocab * vocab,
+           const llama_token * tokens,
+                     int32_t   n_tokens,
+                        char * text,
+                     int32_t   text_len_max,
+                        bool   remove_special,
+                        bool   unparse_special) {
+    return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
+}
+
diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h
index 0d00086da1a..5ce35521434 100644
--- a/examples/talk-llama/llama-vocab.h
+++ b/examples/talk-llama/llama-vocab.h
@@ -4,179 +4,122 @@
 
 #include <string>
 #include <vector>
-#include <unordered_map>
-#include <map>
-#include <set>
-
-static const char * llama_model_vocab_type_name(enum llama_vocab_type type){
-    switch (type) {
-        case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
-        case LLAMA_VOCAB_TYPE_SPM:  return "SPM";
-        case LLAMA_VOCAB_TYPE_BPE:  return "BPE";
-        case LLAMA_VOCAB_TYPE_WPM:  return "WPM";
-        case LLAMA_VOCAB_TYPE_UGM:  return "UGM";
-        case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
-        default:                    return "unknown";
-    }
-}
-
-struct llm_tokenizer;
+#include <memory>
 
-struct llama_vocab {
-    using id    = llama_token;
-    using token = std::string;
-    using tattr = llama_token_attr;
+struct LLM_KV;
+struct llama_model_loader;
 
+struct llama_vocab {
     struct token_data {
-        token text;
-        float score;
-        tattr attr;
+        std::string      text;
+        float            score;
+        llama_token_attr attr;
     };
 
-    uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab
+    llama_vocab();
+    ~llama_vocab();
+
+    void load(llama_model_loader & ml, const LLM_KV & kv);
 
-    enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
-    enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+    enum llama_vocab_type     get_type()     const;
+    enum llama_vocab_pre_type get_pre_type() const;
 
-    int max_token_len = 0; // used for optimizing longest token search
+    uint32_t n_tokens() const;
+    uint32_t n_token_types() const;
 
-    std::unordered_map<token, id> token_to_id;
-    std::vector<token_data>       id_to_token;
+    std::string type_name() const;
 
-    std::vector<id>    cache_special_tokens;
-    std::vector<token> cache_token_to_piece; // llama_token_to_piece(special = true);
+    bool is_normal      (llama_token id) const;
+    bool is_unknown     (llama_token id) const;
+    bool is_control     (llama_token id) const;
+    bool is_byte        (llama_token id) const;
+    bool is_user_defined(llama_token id) const;
+    bool is_unused      (llama_token id) const;
+    bool is_eog         (llama_token id) const;
 
-    std::map<std::pair<std::string, std::string>, int> bpe_ranks;
+    uint8_t     token_to_byte(llama_token id) const;
+    llama_token byte_to_token(uint8_t ch)     const;
 
-    // default LLaMA special tokens
-    // TODO: should we set all of these to LLAMA_TOKEN_NULL?
-    id special_bos_id  = 1;
-    id special_eos_id  = 2;
-    id special_eot_id  = LLAMA_TOKEN_NULL;
-    id special_eom_id  = LLAMA_TOKEN_NULL;
-    id special_unk_id  = 0;
-    id special_sep_id  = LLAMA_TOKEN_NULL;
-    id special_pad_id  = LLAMA_TOKEN_NULL;
-    id special_cls_id  = LLAMA_TOKEN_NULL; // TODO: revisit if this is really needed https://github.com/ggerganov/llama.cpp/pull/10930
-    id special_mask_id = LLAMA_TOKEN_NULL;
+    llama_token text_to_token(const std::string & text) const;
 
-    id linefeed_id = 13;
+    const token_data & get_token_data(llama_token id) const;
 
-    // fim tokens
-    id special_fim_pre_id = LLAMA_TOKEN_NULL;
-    id special_fim_suf_id = LLAMA_TOKEN_NULL;
-    id special_fim_mid_id = LLAMA_TOKEN_NULL;
-    id special_fim_pad_id = LLAMA_TOKEN_NULL;
-    id special_fim_rep_id = LLAMA_TOKEN_NULL; // repo
-    id special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator
+    const char *     token_get_text (llama_token id) const;
+    float            token_get_score(llama_token id) const;
+    llama_token_attr token_get_attr (llama_token id) const;
 
-    // set of all tokens that cause "end of generation"
-    std::set<id> special_eog_ids;
+    llama_token token_bos() const;
+    llama_token token_eos() const;
+    llama_token token_eot() const;
+    llama_token token_eom() const;
+    llama_token token_unk() const;
+    llama_token token_sep() const;
+    llama_token token_nl () const;
+    llama_token token_pad() const;
 
-    // tokenizer flags
-    bool tokenizer_add_space_prefix           = false;
-    bool tokenizer_add_bos                    = false;
-    bool tokenizer_add_eos                    = false;
-    bool tokenizer_ignore_merges              = false;
-    bool tokenizer_clean_spaces               = false;  // clean_up_tokenization_spaces
-    bool tokenizer_remove_extra_whitespaces   = false;
-    bool tokenizer_escape_whitespaces         = true;
-    bool tokenizer_treat_whitespace_as_suffix = false;
+    llama_token token_prefix() const;
+    llama_token token_middle() const;
+    llama_token token_suffix() const;
 
-    std::vector<char> precompiled_charsmap;
+    llama_token token_fim_pre() const;
+    llama_token token_fim_suf() const;
+    llama_token token_fim_mid() const;
+    llama_token token_fim_pad() const;
+    llama_token token_fim_rep() const;
+    llama_token token_fim_sep() const;
 
-    llm_tokenizer * tokenizer = nullptr;
+    bool get_add_space_prefix          () const;
+    bool get_add_bos                   () const;
+    bool get_add_eos                   () const;
+    bool get_ignore_merges             () const;
+    bool get_clean_spaces              () const;
+    bool get_remove_extra_whitespaces  () const;
+    bool get_escape_whitespaces        () const;
+    bool get_treat_whitespace_as_suffix() const;
 
-    llama_vocab() = default;
-    ~llama_vocab();
+    int max_token_len() const;
 
     int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
 
-    void init_tokenizer();
+    int32_t tokenize(
+                   const char * text,
+                      int32_t   text_len,
+                  llama_token * tokens,
+                      int32_t   n_tokens_max,
+                         bool   add_special,
+                         bool   parse_special) const;
+
+    std::vector<llama_token> tokenize(
+            const std::string & raw_text,
+                         bool   add_special,
+                         bool   parse_special = false) const;
+
+    // does not write null-terminator to buf
+    int32_t token_to_piece(
+                  llama_token   token,
+                         char * buf,
+                      int32_t   length,
+                      int32_t   lstrip,
+                         bool   special) const;
+
+    // use cached data
+    const std::string & token_to_piece(llama_token token) const;
+
+    int32_t detokenize(
+            const llama_token * tokens,
+                      int32_t   n_tokens,
+                         char * text,
+                      int32_t   text_len_max,
+                         bool   remove_special,
+                         bool   unparse_special) const;
+
+    std::string detokenize(
+            const std::vector<llama_token> & tokens,
+                                      bool   special) const;
+
+    void print_info() const;
+
+private:
+    struct impl;
+    std::unique_ptr<impl> pimpl;
 };
-
-//
-// internal API
-//
-
-// TODO: rename to llama_tokenize_impl
-// TODO: This should probably be in llama.h
-std::vector<llama_vocab::id> llama_tokenize_internal(
-        const llama_vocab & vocab,
-        std::string raw_text,
-        bool add_special,
-        bool parse_special = false);
-
-// TODO: move the API below as member functions of llama_vocab
-llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
-
-const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
-
-float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token);
-
-llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token);
-
-bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token);
-
-bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token);
-
-llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
-llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
-llama_token llama_token_eot_impl(const struct llama_vocab & vocab);
-llama_token llama_token_eom_impl(const struct llama_vocab & vocab);
-llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
-llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
-llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
-llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
-
-llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
-llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
-llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
-
-llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab);
-llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab);
-
-bool llama_add_bos_token_impl(const struct llama_vocab & vocab);
-bool llama_add_eos_token_impl(const struct llama_vocab & vocab);
-
-int32_t llama_tokenize_impl(
-        const struct llama_vocab & vocab,
-                      const char * text,
-                         int32_t   text_len,
-                     llama_token * tokens,
-                         int32_t   n_tokens_max,
-                            bool   add_special,
-                            bool   parse_special);
-
-// does not write null-terminator to buf
-int32_t llama_token_to_piece_impl(
-        const struct llama_vocab & vocab,
-                     llama_token   token,
-                            char * buf,
-                         int32_t   length,
-                         int32_t   lstrip,
-                            bool   special);
-
-// check if token0 is contained as a prefix in token1
-bool llama_token_is_prefix_impl(
-        const struct llama_vocab & vocab,
-                     llama_token   token0,
-                     llama_token   token1);
-
-int32_t llama_detokenize_impl(
-        const struct llama_vocab & vocab,
-               const llama_token * tokens,
-                         int32_t   n_tokens,
-                            char * text,
-                         int32_t   text_len_max,
-                            bool   remove_special,
-                            bool   unparse_special);
-
-std::string llama_detokenize(
-        const struct llama_vocab & vocab,
-  const std::vector<llama_token> & tokens,
-                            bool   special);
diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp
index ebd6e3b2941..daf1b7c97cd 100644
--- a/examples/talk-llama/llama.cpp
+++ b/examples/talk-llama/llama.cpp
@@ -8,7 +8,6 @@
 #include "llama-kv-cache.h"
 #include "llama-model-loader.h"
 #include "llama-model.h"
-#include "llama-quant.h"
 
 #include "ggml.h"
 #include "ggml-alloc.h"
@@ -18,2485 +17,60 @@
 #include 
 #include 
 #include 
-#include 
 #include 
-#include 
-#include 
 #include 
-#include 
 #include 
 #include 
 #include 
 #include 
 #include 
 #include 
-#include 
-#include 
-#include 
-#include 
-#include 
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-//
-// tensor loading (TODO: add llama_tesor_loader?)
-//
-
-static int llama_get_device_count(const llama_model & model) {
-    return (int) model.devices.size();
-}
-
-// checks if the weight tensor can be used with the specified buffer type and device
-static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) {
-    GGML_ASSERT(w != nullptr);
-
-    if (op == GGML_OP_NONE) {
-        return true;
-    }
-
-    ggml_init_params params = {
-        /*.mem_size   =*/ ggml_tensor_overhead()*8,
-        /*.mem_buffer =*/ NULL,
-        /*.no_alloc   =*/ true,
-    };
-    ggml_context_ptr ctx_ptr { ggml_init(params) };
-    if (!ctx_ptr) {
-        throw std::runtime_error(format("failed to create ggml context"));
-    }
-    ggml_context * ctx = ctx_ptr.get();
-
-    ggml_tensor * op_tensor = nullptr;
-
-    switch (op) {
-        case GGML_OP_GET_ROWS:
-            {
-                ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512);
-                op_tensor = ggml_get_rows(ctx, w, b);
-            } break;
-        case GGML_OP_MUL_MAT:
-            {
-                ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]);
-                op_tensor = ggml_mul_mat(ctx, w, b);
-            } break;
-        case GGML_OP_MUL_MAT_ID:
-            {
-                int n_expert_used = hparams.n_expert_used;
-                ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512);
-                ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512);
-                op_tensor = ggml_mul_mat_id(ctx, w, b, ids);
-            } break;
-        case GGML_OP_ADD:
-            {
-                ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
-                op_tensor = ggml_add(ctx, a, w);
-            } break;
-        case GGML_OP_MUL:
-            {
-                ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
-                op_tensor = ggml_mul(ctx, a, w);
-            } break;
-        case GGML_OP_DIV:
-            {
-                ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, w->ne[0]);
-                op_tensor = ggml_div(ctx, a, w);
-            } break;
-        case GGML_OP_ROPE:
-            {
-                int n_embd_head = hparams.n_embd_head_v;
-                int n_head = hparams.n_head();
-                ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512);
-                ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512);
-                op_tensor = ggml_rope_ext(
-                    ctx, a, b, w,
-                    0, 0, 0, 0, 0,
-                    0, 0, 0, 0
-                );
-
-            } break;
-        case GGML_OP_SSM_CONV:
-            {
-                // FIXME
-                ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789);
-                op_tensor = ggml_ssm_conv(ctx, conv_x, w);
-            } break;
-        case GGML_OP_SSM_SCAN:
-            {
-                // FIXME
-                const int64_t d_state      = w->ne[0];
-                const int64_t d_inner      = w->ne[1];
-                const int64_t n_seq_tokens = 512;
-                const int64_t n_seqs       = 1;
-                ggml_tensor * s  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs);
-                ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
-                ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
-                ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
-                ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
-                op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C);
-            } break;
-        case GGML_OP_RWKV_WKV6:
-            {
-                // FIXME
-                const int64_t S = 123;
-                const int64_t H = 123;
-                const int64_t n_tokens = 123;
-                const int64_t n_seqs = 123;
-                ggml_tensor  * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, n_tokens);
-                ggml_tensor  * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
-                ggml_tensor  * r = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
-                ggml_tensor  * tf = w;
-                ggml_tensor  * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
-                ggml_tensor  * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H);
-                op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state);
-            } break;
-        case GGML_OP_IM2COL:
-            {
-                const int n_embd = hparams.n_embd;
-                ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1);
-                op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16);
-            } break;
-        default:
-            GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name);
-    }
-
-    // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
-    GGML_ASSERT(w->buffer == nullptr);
-    w->buffer = ggml_backend_buft_alloc_buffer(buft, 0);
-    bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor);
-    ggml_backend_buffer_free(w->buffer);
-    w->buffer = nullptr;
-
-    return op_supported;
-}
-
-// find the first buffer type in the list that can use the tensor
-static ggml_backend_buffer_type_t select_weight_buft(const llama_model & model, ggml_tensor * tensor, ggml_op op, const llama_model::buft_list_t & buft_list) {
-    GGML_ASSERT(!buft_list.empty());
-    for (const auto & cur : buft_list) {
-        ggml_backend_dev_t cur_dev = cur.first;
-        ggml_backend_buffer_type_t cur_buft = cur.second;
-        if (weight_buft_supported(model.hparams, tensor, op, cur_buft, cur_dev)) {
-            return cur_buft;
-        }
-    }
-    return nullptr;
-}
-
-// CPU: ACCEL -> CPU extra -> GPU host -> CPU
-static llama_model::buft_list_t make_cpu_buft_list(llama_model & model) {
-    llama_model::buft_list_t buft_list;
-
-    // add ACCEL buffer types
-    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
-        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
-            auto * buft = ggml_backend_dev_buffer_type(dev);
-            // skip
-            if (buft != ggml_backend_cpu_buffer_type()) {
-                buft_list.emplace_back(dev, buft);
-            }
-        }
-    }
-
-    // add extra buffer types
-    auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-    auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
-    auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
-        ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
-    if (ggml_backend_dev_get_extra_bufts_fn) {
-        ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
-        while (extra_bufts && *extra_bufts) {
-            buft_list.emplace_back(cpu_dev, *extra_bufts);
-            ++extra_bufts;
-        }
-    }
-
-    // add a host buffer type
-    // storing the tensors in a host buffer is useful when the processing of large batches
-    // is offloaded to a GPU device, since it reduces the time spent on data transfers
-    // generally, this will be done using the first device in the list
-    // a better approach would be to handle this on a weight-by-weight basis using the offload_op
-    // function of the device to determine if it would benefit from being stored in a host buffer
-    for (auto * dev : model.devices) {
-        ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev);
-        if (buft) {
-            buft_list.emplace_back(dev, buft);
-            break;
-        }
-    }
-
-    // add the CPU buffer type
-    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
-        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) {
-            buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev));
-        }
-    }
-
-    return buft_list;
-}
-
-// GPU: split if LLAMA_SPLIT_MODE_ROW -> GPU
-static llama_model::buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, enum llama_split_mode split_mode, const float * tensor_split) {
-    llama_model::buft_list_t buft_list;
-
-    // add the device split buffer type if requested and available
-    if (split_mode == LLAMA_SPLIT_MODE_ROW) {
-        ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev);
-        auto ggml_backend_split_buffer_type_fn = (ggml_backend_split_buffer_type_t)
-            ggml_backend_reg_get_proc_address(reg, "ggml_backend_split_buffer_type");
-        if (ggml_backend_split_buffer_type_fn) {
-            size_t dev_index = [&]() {
-                auto * reg = ggml_backend_dev_backend_reg(dev);
-                for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); ++i) {
-                    if (ggml_backend_reg_dev_get(reg, i) == dev) {
-                        return i;
-                    }
-                }
-                throw std::runtime_error(format("device %s not found in its backend reg", ggml_backend_dev_name(dev)));
-            }();
-            auto * buft = ggml_backend_split_buffer_type_fn(dev_index, tensor_split);
-            if (buft != nullptr) {
-                buft_list.emplace_back(dev, buft);
-            }
-        }
-    }
-
-    // add the device default buffer type
-    buft_list.emplace_back(dev, ggml_backend_dev_buffer_type(dev));
-
-    return buft_list;
-}
-
-// Returns false if cancelled by progress_callback
-static bool llm_load_tensors(
-        llama_model_loader & ml,
-        llama_model & model,
-        int n_gpu_layers,
-        enum llama_split_mode split_mode,
-        int main_gpu,
-        const float * tensor_split,
-        bool use_mlock,
-        llama_progress_callback progress_callback,
-        void * progress_callback_user_data) {
-    auto & hparams = model.hparams;
-
-    model.split_mode   = split_mode;
-    model.main_gpu     = main_gpu;
-    model.n_gpu_layers = n_gpu_layers;
-
-    const int n_layer = hparams.n_layer;
-
-    bool use_mmap_buffer = true;
-
-    // build a list of buffer types for the CPU and GPU devices
-    model.cpu_buft_list = make_cpu_buft_list(model);
-    for (auto * dev : model.devices) {
-        llama_model::buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split);
-        // add CPU buffer types as a fallback
-        buft_list.insert(buft_list.end(), model.cpu_buft_list.begin(), model.cpu_buft_list.end());
-        model.gpu_buft_list.emplace(dev, std::move(buft_list));
-    }
-
-    // calculate the split points
-    int device_count = llama_get_device_count(model);
-    bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + device_count, [](float x) { return x == 0.0f; });
-    std::vector<float> splits(device_count);
-    if (all_zero) {
-        // default split, by free memory
-        for (int i = 0; i < device_count; ++i) {
-            ggml_backend_dev_t dev = model.devices[i];
-            size_t total;
-            size_t free;
-            ggml_backend_dev_memory(dev, &free, &total);
-            splits[i] = free;
-        }
-    } else {
-        std::copy(tensor_split, tensor_split + device_count, splits.begin());
-    }
-
-    // sum and normalize the splits to get the split points
-    float split_sum = 0.0f;
-    for (int i = 0; i < device_count; ++i) {
-        split_sum += splits[i];
-        splits[i] = split_sum;
-    }
-    for (int i = 0; i < device_count; ++i) {
-        splits[i] /= split_sum;
-    }
-
-    ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
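-    // i_gpu_start is the first repeating layer that is offloaded; act_gpu_layers is clamped to n_layer + 1 so that the output layer can also be offloaded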
-    const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
-    const int act_gpu_layers = model.devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1);
-    auto get_layer_buft_list = [&](int il) -> llama_model::layer_dev {
-        if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) {
-            return {cpu_dev, &model.cpu_buft_list};
-        }
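-        // splits now holds cumulative normalized shares - pick the first device whose cumulative share exceeds this layer's relative position among the offloaded layers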
-        int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(il - i_gpu_start)/act_gpu_layers) - splits.begin();
-        auto * dev = model.devices.at(layer_gpu);
-        return {dev, &model.gpu_buft_list.at(dev)};
-    };
-
-    // assign the input layer
-    // there is very little benefit to offloading the input layer, so always keep it on the CPU
-    model.dev_input = { cpu_dev, &model.cpu_buft_list };
-
-    // assign the repeating layers to the devices according to the splits
-    model.dev_layer.resize(n_layer);
-    for (int il = 0; il < n_layer; ++il) {
-        model.dev_layer[il] = get_layer_buft_list(il);
-    }
-    // assign the output layer
-    model.dev_output = get_layer_buft_list(n_layer);
-
-    // one ggml context per buffer type
-    int max_n_tensors = ml.n_tensors;
-    max_n_tensors += 1;         // duplicated output tensor
-    max_n_tensors += n_layer*2; // duplicated rope freq tensors
-    const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors;
-
-    std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
-    auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
-        auto it = ctx_map.find(buft);
-        if (it == ctx_map.end()) {
-            ggml_init_params params = {
-                /*.mem_size   =*/ ctx_size,
-                /*.mem_buffer =*/ NULL,
-                /*.no_alloc   =*/ true,
-            };
-            ggml_context * ctx = ggml_init(params);
-            if (!ctx) {
-                throw std::runtime_error(format("failed to create ggml context"));
-            }
-            ctx_map[buft] = ctx;
-            model.ctxs.emplace_back(ctx);
-            return ctx;
-        }
-        return it->second;
-    };
-
-    // create tensors for the weights
-    {
-        // note: cast to int64_t since we will use these for the tensor dimensions
-        const int64_t n_head        = hparams.n_head();
-        const int64_t n_head_kv     = hparams.n_head_kv();
-        const int64_t n_embd        = hparams.n_embd;
-        const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa();
-        const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa();
-        const int64_t n_embd_head_k = hparams.n_embd_head_k;
-        const int64_t n_embd_head_v = hparams.n_embd_head_v;
-        const int64_t n_ff          = hparams.n_ff();
-        const int64_t n_embd_gqa    = n_embd_v_gqa;
-        const int64_t n_vocab       = hparams.n_vocab;
-        const int64_t n_vocab_type  = hparams.n_vocab_type;
-        const int64_t n_rot         = hparams.n_rot;
-        const int64_t n_expert      = hparams.n_expert;
-        const int64_t n_expert_used = hparams.n_expert_used;
-        const int64_t n_ctx_train   = hparams.n_ctx_train;
-
-        if (n_expert > 0 && hparams.n_expert_used == 0) {
-            throw std::runtime_error("model has expert layers but no expert layers are used");
-        }
-
-        int n_moved_tensors = 0;
-        ggml_tensor * first_moved_tensor = nullptr;
-        ggml_backend_buffer_type_t first_moved_from_buft = nullptr;
-        ggml_backend_buffer_type_t first_moved_to_buft = nullptr;
-
-        auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags) -> ggml_tensor * {
-            ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str());
-
-            if (!t_meta) {
-                if (flags & llama_model_loader::TENSOR_NOT_REQUIRED) {
-                    return nullptr;
-                }
-                throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str()));
-            }
-
-            // some models use the token embedding tensor as the output, but since these are used in different layers and with different ops
-            // the tensor is duplicated
-            // to handle this, we check if the tensor is duplicated, and if so, we assume that it is being loaded as the output tensor
-            llm_tensor tn_tensor = tn.tensor;
-            if (tn.tensor == LLM_TENSOR_TOKEN_EMBD && flags & llama_model_loader::TENSOR_DUPLICATED) {
-                tn_tensor = LLM_TENSOR_OUTPUT;
-            }
-
-            llm_tensor_info info;
-            try {
-                info = llm_tensor_info_for(tn_tensor);
-            } catch (const std::out_of_range & e) {
-                throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str()));
-            }
-
-            // tensors with "bias" suffix are always used with GGML_OP_ADD
-            ggml_op op;
-            bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0;
-            if (bias) {
-                op = GGML_OP_ADD;
-            } else {
-                op = info.op;
-            }
-
-            // sanity checks
-            if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) {
-                if (tn.bid != -1) {
-                    GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str());
-                }
-            } else {
-                if (tn.bid == -1) {
-                    GGML_ABORT("repeating layer tensor %s used without a layer number", tn.str().c_str());
-                }
-            }
-
-            // select the buffer type for this tensor
-            llama_model::buft_list_t * buft_list;
-            switch (info.layer) {
-                case LLM_TENSOR_LAYER_INPUT:
-                    buft_list = model.dev_input.buft_list;
-                    break;
-                case LLM_TENSOR_LAYER_OUTPUT:
-                    buft_list = model.dev_output.buft_list;
-                    break;
-                case LLM_TENSOR_LAYER_REPEATING:
-                    buft_list = model.dev_layer.at(tn.bid).buft_list;
-                    break;
-                default:
-                    GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str());
-            }
-
-            ggml_backend_buffer_type_t buft = select_weight_buft(model, t_meta, op, *buft_list);
-            if (!buft) {
-                throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str()));
-            }
-
-            // avoid using a host buffer when using mmap
-            auto * buft_dev = ggml_backend_buft_get_device(buft);
-            if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
-                auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-                buft = ggml_backend_dev_buffer_type(cpu_dev);
-            }
-
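-            // count the tensors that could not use the first (preferred) buffer type in the list, and remember the first such tensor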
-            if (buft != buft_list->front().second) {
-                n_moved_tensors++;
-                if (!first_moved_tensor) {
-                    first_moved_tensor = t_meta;
-                    first_moved_from_buft = buft_list->front().second;
-                    first_moved_to_buft   = buft;
-                }
-            }
-
-            ggml_context * ctx = ctx_for_buft(buft);
-
-            // if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one
-            if (flags & llama_model_loader::TENSOR_DUPLICATED) {
-                ggml_tensor * t = ggml_get_tensor(ctx, tn.str().c_str());
-                if (t) {
-                    return t;
-                }
-            }
-            return ml.create_tensor(ctx, tn, ne, flags);
-        };
-
-        model.layers.resize(n_layer);
-
-        // TODO: move to a separate function
-        const auto tn = LLM_TN(model.arch);
-        switch (model.arch) {
-            case LLM_ARCH_LLAMA:
-            case LLM_ARCH_REFACT:
-            case LLM_ARCH_MINICPM:
-            case LLM_ARCH_GRANITE:
-            case LLM_ARCH_GRANITE_MOE:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-
-                        // optional bias tensors
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
-                            layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                            layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        }
-                        else {
-                            layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        }
-
-                        if (n_expert == 0) {
-                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-
-                            // optional MLP bias
-                            layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                            layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                            layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        } else {
-                            layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
-                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
-                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
-                        }
-                    }
-                } break;
-            case LLM_ARCH_DECI:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-                        const int64_t n_embd_k_gqa  = hparams.n_embd_k_gqa(i);
-                        const int64_t n_embd_v_gqa  = hparams.n_embd_v_gqa(i);
-                        const int64_t n_embd_gqa    = hparams.n_embd_v_gqa(i);
-                        const int64_t n_ff          = hparams.n_ff(i);
-                        const int64_t n_head        = hparams.n_head(i);
-                        const int64_t n_head_kv     = hparams.n_head_kv(i);
-
-                        if (n_head_kv == 0 && n_head > 0) {
-                            // linear attention for DeciLMCausalModel
-                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        }
-                        else if (n_head_kv > 0) {
-                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                            layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-                        }
-
-                        // optional bias tensors
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
-                            layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                            layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        }
-                        else {
-                            layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        }
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-
-                        // optional MLP bias
-                        layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-                } break;
-            case LLM_ARCH_MINICPM3:
-                {
-                    const int64_t n_embd_head_qk_rope = hparams.n_rot;
-                    const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
-
-                    const int64_t q_lora_rank  = hparams.n_lora_q;
-                    const int64_t kv_lora_rank = hparams.n_lora_kv;
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0);
-
-                        layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
-
-                        layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
-                        layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0);
-
-                        layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
-                        layer.wkv_b     = create_tensor(tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
-                        layer.wo        = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {              n_head * (                      n_embd_head_v), n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-
-                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                    }
-                } break;
-            case LLM_ARCH_GROK:
-                {
-                    if (n_expert == 0) {
-                        throw std::runtime_error("Grok model cannot have zero experts");
-                    }
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
-
-                        layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_DBRX:
-                {
-                    if (n_expert == 0) {
-                        throw std::runtime_error("DBRX model cannot have zero experts");
-                    }
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff,   n_expert}, 0);
-                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff,   n_embd, n_expert}, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff,   n_expert}, 0);
-                    }
-                } break;
-            case LLM_ARCH_BAICHUAN:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-                    {
-                        model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                        model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_FALCON:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    {
-                        model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                        model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-
-                        model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        if (!model.output) {
-                            model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU
-                        }
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.attn_norm_2   = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_STARCODER:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-                    model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train}, 0);
-
-                    // output
-                    {
-                        model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                        model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                        model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        if (!model.output) {
-                            // needs to be on GPU
-                            model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                        }
-
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i),   {n_embd, n_ff}, 0);
-                        layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i),     {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_BERT:
-            case LLM_ARCH_NOMIC_BERT:
-                {
-                    model.tok_embd     = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0);
-                    model.type_embd    = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}, 0);
-
-                    if (model.arch == LLM_ARCH_BERT) {
-                        model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD,    "weight"), {n_embd, n_ctx_train}, 0);
-
-                        model.cls   = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        model.cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"),   {n_embd},         llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        model.cls_out   = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        model.cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"),   {1},         llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-
-                    model.tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                    model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        if (model.arch == LLM_ARCH_BERT) {
-                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                            layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i),   {n_embd}, 0);
-
-                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                            layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i),   {n_embd_gqa}, 0);
-
-                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                            layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i),   {n_embd_gqa}, 0);
-                        } else {
-                            layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        }
-
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,        "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN,      "weight", i), {n_ff, n_embd}, 0);
-
-                        if (model.arch == LLM_ARCH_BERT) {
-                            layer.bo         = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
-                            layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, 0);
-                            layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
-                        } else {
-                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
-                        }
-
-                        layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
-                        layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i),   {n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_JINA_BERT_V2:
-                {
-                    model.tok_embd  = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0); // word_embeddings
-                    model.type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}, 0); // token_type_embeddings
-
-                    model.tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); // LayerNorm
-                    model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0); // LayerNorm bias
-
-                    model.cls   = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    model.cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"),   {1},         llama_model_loader::TENSOR_NOT_REQUIRED);
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i]; // JinaBertLayer
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd}, 0);
-
-                        layer.attn_q_norm   = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias",   i), {n_embd_gqa}, 0);
-
-                        layer.attn_k_norm   = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias",   i), {n_embd_gqa}, 0);
-
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); // output_dense
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd}, 0); // output_dense
-
-                        layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); // output_norm
-                        layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias",   i), {n_embd}, 0);
-
-                        layer.attn_norm_2   = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd}, 0);
-
-                        layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
-                        layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias",   i), {n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_BLOOM:
-                {
-                    model.tok_embd   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab}, 0);
-                    model.tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                    model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {n_embd}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias",   i), {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias",   i), {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias",   i), {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias",   i), {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias",   i), {n_embd}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias",   i), {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_MPT:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-                    model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    if (!model.output) {
-                        model.output    = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // needs to be on GPU
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.attn_q_norm   = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.attn_k_norm   = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias",   i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        // AWQ ScaleActivation layer
-                        layer.ffn_act = create_tensor(tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-                } break;
-            case LLM_ARCH_STABLELM:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm =   create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        // optional bias tensors, present in Stable LM 2 1.6B
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        // optional q and k layernorms, present in StableLM 2 12B
-                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head},    llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        // optional FFN norm, not present in StableLM 2 12B which uses parallel residual
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_QWEN:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd*3}, 0);
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff/2}, 0);
-                    }
-                } break;
-            case LLM_ARCH_QWEN2:
-            case LLM_ARCH_QWEN2VL:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        // bias tensors (required for this architecture)
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd}, 0);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, 0);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_QWEN2MOE:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        // bias tensors (required for this architecture)
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd}, 0);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, 0);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
-
-                        if (n_expert == 0) {
-                            throw std::runtime_error("n_expert must be > 0 for QWEN2MOE");
-                        }
-                        if (n_expert_used == 0) {
-                            throw std::runtime_error("n_expert_used must be > 0 for QWEN2MOE");
-                        }
-
-                        // MoE branch
-                        const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
-
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
-                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
-
-                        // Shared expert branch
-                        const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
-
-                        layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {    n_embd, n_ff_shexp}, 0);
-                        layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp,     n_embd}, 0);
-                        layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {    n_embd, n_ff_shexp}, 0);
-                    }
-                } break;
-            case LLM_ARCH_PHI2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-                    model.output_b      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "bias"),   {n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        if (layer.wqkv == nullptr) {
-                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
-                            layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd}, 0);
-
-                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
-                            layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i),   {n_embd_gqa}, 0);
-
-                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
-                            layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i),   {n_embd_gqa}, 0);
-                        }
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_PHI3:
-                {
-                    const int64_t n_embd_head = n_embd / n_head;
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
-
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
-                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);
-
-                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                    }
-                } break;
-            case LLM_ARCH_PLAMO:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_GPT2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-                    model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD,   "weight"), {n_embd, n_ctx_train}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_CODESHELL:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i),   {n_embd, n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i),     {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_ORION:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_INTERNLM2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        // layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_GEMMA:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_GEMMA2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_STARCODER2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        // bias tensors
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd}, 0);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, 0);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, 0);
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-
-                        // bias tensors
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {  n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_MAMBA:
-                {
-                    const int64_t d_conv  = hparams.ssm_d_conv;
-                    const int64_t d_inner = hparams.ssm_d_inner;
-                    const int64_t d_state = hparams.ssm_d_state;
-                    const int64_t dt_rank = hparams.ssm_dt_rank;
-
-                    // only an expansion factor of 2 is supported for now
-                    if (2 * n_embd != d_inner) {
-                        throw std::runtime_error("only an expansion factor of 2 is supported for now");
-                    }
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed, duplicated to allow offloading
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        // norm
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0);
-
-                        layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0);
-                        layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0);
-
-                        layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0);
-
-                        layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0);
-                        layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0);
-
-                        // no "weight" suffix for these
-                        layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0);
-                        layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0);
-
-                        // out_proj
-                        layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_XVERSE:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_COMMAND_R:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    // init output from the input tok embed
-                    model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        if (n_layer >= 64) {
-                            layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0);
-                            layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0);
-                        }
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_COHERE2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
-                    // init output from the input tok embed
-                    model.output      = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab },
-                                                      llama_model_loader::TENSOR_DUPLICATED);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
-                    }
-                } break;
-            case LLM_ARCH_OLMO:  // adapted from LLM_ARCH_LLAMA with norm params removed
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_OLMO2:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_OLMOE:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
-
-                        if (n_expert == 0) {
-                            throw std::runtime_error("n_expert must be > 0");
-                        }
-                        if (n_expert_used == 0) {
-                            throw std::runtime_error("n_expert_used must be > 0");
-                        }
-
-                        // MoE branch
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff,   n_expert}, 0);
-                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff,   n_embd, n_expert}, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff,   n_expert}, 0);
-                    }
-                } break;
-            case LLM_ARCH_OPENELM:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    // init output from the input tok embed
-                    model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        const int64_t n_head      =   hparams.n_head(i);
-                        const int64_t n_head_qkv  = 2*hparams.n_head_kv(i) + n_head;
-                        const int64_t n_ff        =   hparams.n_ff(i);
-
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k}, 0);
-                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
-                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_GPTNEOX:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_ARCTIC:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
-                        layer.ffn_norm_exps = create_tensor(tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd,   n_ff, n_expert}, 0);
-                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd,   n_ff, n_expert}, 0);
-                    }
-                } break;
-            case LLM_ARCH_DEEPSEEK:
-                {
-                    const int64_t n_ff_exp        = hparams.n_ff_exp;
-                    const int64_t n_expert_shared = hparams.n_expert_shared;
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        if (i < (int) hparams.n_layer_dense_lead) {
-                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                        } else {
-                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
-
-                            if (n_expert == 0) {
-                                throw std::runtime_error("n_expert must be > 0");
-                            }
-                            if (n_expert_used == 0) {
-                                throw std::runtime_error("n_expert_used must be > 0");
-                            }
-
-                            // MoE branch
-                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
-                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
-                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
-
-                            // Shared expert branch
-                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
-                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {        n_ff_exp * n_expert_shared, n_embd}, 0);
-                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
-                        }
-                    }
-                } break;
-            case LLM_ARCH_DEEPSEEK2:
-                {
-                    const bool is_lite = (hparams.n_layer == 27);
-
-                    const int64_t n_embd_head_qk_rope = hparams.n_rot;
-                    const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
-
-                    const int64_t q_lora_rank  = hparams.n_lora_q;
-                    const int64_t kv_lora_rank = hparams.n_lora_kv;
-
-                    const int64_t n_ff_exp        = hparams.n_ff_exp;
-                    const int64_t n_expert_shared = hparams.n_expert_shared;
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        if (!is_lite) {
-                            layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0);
-                        }
-
-                        layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
-
-                        if (!is_lite) {
-                            layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
-                            layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0);
-                        } else {
-                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        }
-
-                        layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
-                        layer.wkv_b     = create_tensor(tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
-                        layer.wo        = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {              n_head * (                      n_embd_head_v), n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        if (i < (int) hparams.n_layer_dense_lead) {
-                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                        } else {
-                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
-                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                            if (n_expert == 0) {
-                                throw std::runtime_error("n_expert must be > 0");
-                            }
-                            if (n_expert_used == 0) {
-                                throw std::runtime_error("n_expert_used must be > 0");
-                            }
-
-                            // MoE branch
-                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
-                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
-                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
-
-                            // Shared expert branch
-                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
-                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {        n_ff_exp * n_expert_shared, n_embd}, 0);
-                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
-                        }
-                    }
-                } break;
-            case LLM_ARCH_BITNET:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm     = create_tensor(tn(LLM_TENSOR_ATTN_NORM,     "weight", i), {n_embd}, 0);
-                        layer.attn_sub_norm = create_tensor(tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq       = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wq_scale = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.wk       = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wk_scale = create_tensor(tn(LLM_TENSOR_ATTN_K,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.wv       = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv_scale = create_tensor(tn(LLM_TENSOR_ATTN_V,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.wo       = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.wo_scale = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_norm     = create_tensor(tn(LLM_TENSOR_FFN_NORM,     "weight", i), {n_embd}, 0);
-                        layer.ffn_sub_norm = create_tensor(tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}, 0);
-
-                        layer.ffn_gate       = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_gate_scale = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down       = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_scale = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_up         = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_up_scale   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "scale",  i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-                } break;
-            case LLM_ARCH_T5:
-                {
-                    const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm     = create_tensor(tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}, 0);
-
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm_enc  = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd}, 0);
-                        layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
-
-                        layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up_enc   = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-
-                        layer.attn_norm  = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM,  "weight", i), {n_embd}, 0);
-                        layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_DEC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_DEC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_DEC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
-
-                        layer.attn_norm_cross  = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM,  "weight", i), {n_embd}, 0);
-                        // this tensor seems to be unused in the HF transformers implementation
-                        layer.attn_rel_b_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wq_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wk_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_DEC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_T5ENCODER:
-                {
-                    const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts;
-
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm_enc  = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM,  "weight", i), {n_embd}, 0);
-                        layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0);
-
-                        layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd,   n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up_enc   = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_JAIS:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd}, 0);
-
-                        layer.ffn_gate   = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "bias", i),   {n_ff}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_CHATGLM:
-                {
-                    model.tok_embd   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
-                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa}, 0);
-
-                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff * 2}, 0);
-
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
-                    }
-                } break;
-            case LLM_ARCH_NEMOTRON:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
-                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        // optional bias tensors
-                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0);
-
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-
-                        // optional MLP bias
-                        layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    }
-                } break;
-            case LLM_ARCH_EXAONE:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
-
-                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM,   "weight", i), {n_embd}, 0);
-                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
-                        layer.ffn_gate   = create_tensor(tn(LLM_TENSOR_FFN_GATE,   "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down   = create_tensor(tn(LLM_TENSOR_FFN_DOWN,   "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up     = create_tensor(tn(LLM_TENSOR_FFN_UP,     "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_RWKV6:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // Block 0, LN0
-                    model.tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
-                    model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0);
-                    model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
-
-                    const int time_mix_extra_dim = hparams.time_mix_extra_dim;
-                    const int time_decay_extra_dim = hparams.time_decay_extra_dim;
-                    const int head_size = hparams.wkv_head_size;
-                    const int attn_hidden_size = n_embd;
-                    const int ffn_size = hparams.n_ff_arr[0];
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd}, 0);
-
-                        layer.attn_norm_2   = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0);
-                        layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i),   {n_embd}, 0);
-
-                        layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0);
-                        layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0);
-
-                        layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, 0);
-
-                        layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0);
-                        layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0);
-                        layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0);
-                        layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0);
-                        layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0);
-                        layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0);
-                        layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0);
-                        layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0);
-
-                        layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0);
-                        layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0);
-                        layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0);
-
-                        layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0);
-                        layer.channel_mix_lerp_r = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0);
-
-                        layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0);
-                        layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0);
-                        layer.channel_mix_receptance = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd}, 0);
-                    }
-
-                } break;
-            case LLM_ARCH_CHAMELEON:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
-
-                    // output
-                    model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                    model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                    // if output is NULL, init from the input tok embed
-                    if (model.output == NULL) {
-                        model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
-                    }
-
-                    for (int i = 0; i < n_layer; ++i) {
-                        auto & layer = model.layers[i];
-
-                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0);
-                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0);
-                        layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i),  {n_embd_head_k, n_head}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i),  {n_embd_head_k, n_head_kv}, llama_model_loader::TENSOR_NOT_REQUIRED);
-
-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
-
-                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
-
-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
-                    }
-                } break;
-            case LLM_ARCH_WAVTOKENIZER_DEC:
-                {
-                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd_features, n_vocab}, 0);
-
-                    model.conv1d   = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd_features, hparams.posnet.n_embd}, 0);
-                    model.conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"),   {1, hparams.posnet.n_embd}, 0);
-
-                    // posnet
-                    {
-                        const int64_t n_embd = hparams.posnet.n_embd;
-
-                        for (uint32_t i = 0; i < hparams.posnet.n_layer; ++i) {
-                            auto & layer = model.layers[i].posnet;
-
-                            // posnet:
-                            //
-                            //  - resnet
-                            //  - resnet
-                            //  - attn
-                            //  - resnet
-                            //  - resnet
-                            //  - norm
-                            //
-                            switch (i) {
-                                case 0:
-                                case 1:
-                                case 3:
-                                case 4:
-                                    {
-                                        layer.norm1   = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", i), {1, n_embd}, 0);
-                                        layer.norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias",   i), {1, n_embd}, 0);
-
-                                        layer.conv1   = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", i), {3, n_embd, n_embd}, 0);
-                                        layer.conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias",   i), {1, n_embd}, 0);
-
-                                        layer.norm2   = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", i), {1, n_embd}, 0);
-                                        layer.norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias",   i), {1, n_embd}, 0);
-
-                                        layer.conv2   = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", i), {3, n_embd, n_embd}, 0);
-                                        layer.conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias",   i), {1, n_embd}, 0);
-                                    } break;
-                                case 2:
-                                    {
-                                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0);
-                                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias",   i), {1, n_embd}, 0);
-
-                                        layer.attn_q      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q,    "weight", i), {1, n_embd, n_embd}, 0);
-                                        layer.attn_q_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q,    "bias",   i), {1, n_embd}, 0);
-
-                                        layer.attn_k      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K,    "weight", i), {1, n_embd, n_embd}, 0);
-                                        layer.attn_k_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K,    "bias",   i), {1, n_embd}, 0);
-
-                                        layer.attn_v      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V,    "weight", i), {1, n_embd, n_embd}, 0);
-                                        layer.attn_v_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V,    "bias",   i), {1, n_embd}, 0);
-
-                                        layer.attn_o      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT,  "weight", i), {1, n_embd, n_embd}, 0);
-                                        layer.attn_o_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT,  "bias",   i), {1, n_embd}, 0);
-                                    } break;
-                                case 5:
-                                    {
-                                        layer.norm   = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0);
-                                        layer.norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias",   i), {1, n_embd}, 0);
-                                    } break;
-                                default: GGML_ABORT("unknown posnet layer");
-                            };
-                        }
-                    }
-
-                    GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd);
-
-                    model.tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {hparams.posnet.n_embd}, 0);
-                    model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"),   {hparams.posnet.n_embd}, 0);
-
-                    // convnext
-                    {
-                        const int64_t n_embd = hparams.convnext.n_embd;
-
-                        for (uint32_t i = 0; i < hparams.convnext.n_layer; ++i) {
-                            auto & layer = model.layers[i].convnext;
-
-                            layer.dw     = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW,    "weight", i), {7, 1, n_embd}, 0);
-                            layer.dw_b   = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW,    "bias",   i), {1, n_embd}, 0);
-
-                            layer.norm   = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM,  "weight", i), {n_embd}, 0);
-                            layer.norm_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM,  "bias",   i), {n_embd}, 0);
-
-                            layer.pw1    = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1,   "weight", i), {n_embd, n_ff}, 0);
-                            layer.pw1_b  = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1,   "bias",   i), {n_ff}, 0);
-
-                            layer.pw2    = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2,   "weight", i), {n_ff, n_embd}, 0);
-                            layer.pw2_b  = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2,   "bias",   i), {n_embd}, 0);
-
-                            layer.gamma  = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", i), {n_embd}, 0);
-                        }
-
-                        // output
-                        model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
-                        model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
-                    }
-
-                    model.output   = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0);
-                    model.output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"),   {n_embd}, 0);
-                } break;
-            default:
-                throw std::runtime_error("unknown architecture");
-        }
-
-        if (n_moved_tensors > 0) {
-            LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %d others) cannot be used with preferred buffer type %s, using %s instead\n",
-                __func__, first_moved_tensor->name, ggml_type_name(first_moved_tensor->type), n_moved_tensors - 1,
-                ggml_backend_buft_name(first_moved_from_buft), ggml_backend_buft_name(first_moved_to_buft));
-        }
-    }
-
-    ml.done_getting_tensors();
-
-    ml.init_mappings(true, use_mlock ? &model.mlock_mmaps : nullptr);
-    model.mappings.reserve(ml.mappings.size());
-
-    // create the backend buffers
-    std::vector<std::pair<ggml_context *, llama_buf_map>> ctx_bufs;
-    ctx_bufs.reserve(ctx_map.size());
-
-    // Ensure we have enough capacity for the maximum backend buffer we will potentially create
-    const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size();
-    model.bufs.reserve(n_max_backend_buffer);
-
-    for (auto & it : ctx_map) {
-        ggml_backend_buffer_type_t buft = it.first;
-        ggml_context * ctx              = it.second;
-
-        // skip contexts without tensors
-        if (ggml_get_first_tensor(ctx) == nullptr) {
-            continue;
-        }
-
-        llama_buf_map bufs;
-        bufs.reserve(n_max_backend_buffer);
-
-        // check if it is possible to use buffer_from_host_ptr with this buffer type
-        ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft);
-        if (!dev) {
-            // FIXME: workaround for CPU backend buft having a NULL device
-            dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-        }
-        ggml_backend_dev_props props;
-        ggml_backend_dev_get_props(dev, &props);
-        bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr;
-        bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev);
-
-        if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) {
-            for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
-                // only the mmap region containing the tensors in the model is mapped to the backend buffer
-                // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers
-                // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size
-                void * addr = nullptr;
-                size_t first, last; // NOLINT
-                ml.get_mapping_range(&first, &last, &addr, idx, ctx);
-                if (first >= last) {
-                    continue;
-                }
-                const size_t max_size = ggml_get_max_tensor_size(ctx);
-                ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
-                if (buf == nullptr) {
-                    throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
-                }
-                model.bufs.emplace_back(buf);
-                bufs.emplace(idx, buf);
-            }
-        }
-        else {
-            ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
-            if (buf == nullptr) {
-                throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
-            }
-            model.bufs.emplace_back(buf);
-            if (use_mlock && ggml_backend_buffer_is_host(buf)) {
-                model.mlock_bufs.emplace_back(new llama_mlock);
-                auto & mlock_buf = model.mlock_bufs.back();
-                mlock_buf->init   (ggml_backend_buffer_get_base(buf));
-                mlock_buf->grow_to(ggml_backend_buffer_get_size(buf));
-            }
-            for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
-                bufs.emplace(idx, buf);
-            }
-        }
-
-        if (bufs.empty()) {
-            throw std::runtime_error("failed to allocate buffer");
-        }
-
-        for (auto & buf : bufs) {
-            // indicate that this buffer contains weights
-            // this is used by ggml_backend_sched to improve op scheduling: ops that use a weight are preferably scheduled to the backend that contains the weight
-            ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
-        }
-
-        ctx_bufs.emplace_back(ctx, bufs);
-    }
-
-    if (llama_supports_gpu_offload()) {
-        const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
-
-        LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_gpu);
-        if (n_gpu_layers > (int) hparams.n_layer) {
-            LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__);
-        }
-
-        const int max_backend_supported_layers = hparams.n_layer + 1;
-        const int max_offloadable_layers       = hparams.n_layer + 1;
-
-        LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers);
-    }
-
-    // print memory requirements per buffer type
-    for (auto & buf : model.bufs) {
-        LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0);
-    }
-
-    // populate tensors_by_name
-    for (auto & ctx : model.ctxs) {
-        for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) {
-            model.tensors_by_name.emplace_back(ggml_get_name(cur), cur);
-        }
-    }
-
-    // load tensor data
-    for (auto & it : ctx_bufs) {
-        ggml_context * ctx = it.first;
-        auto & bufs = it.second;
-        if (!ml.load_all_data(ctx, bufs, use_mlock ? &model.mlock_mmaps : NULL, progress_callback, progress_callback_user_data)) {
-            return false;
-        }
-    }
-
-    if (use_mmap_buffer) {
-        for (auto & mapping : ml.mappings) {
-            model.mappings.emplace_back(std::move(mapping));
-        }
-    }
-
-    return true;
-}
-
 // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
 static int llama_model_load(const std::string & fname, llama_model & model, llama_model_params & params) {
-    model.t_start_us = ggml_time_us();
+    // loading time will be recalculated after the first eval, so
+    // we take page faults deferred by mmap() into consideration
+    model.t_load_us = 0;
+    time_meas tm(model.t_load_us);
+
+    model.t_start_us = tm.t_start_us;
 
     try {
         llama_model_loader ml(fname, params.use_mmap, params.check_tensors, params.kv_overrides);
 
+        ml.print_info();
+
         model.hparams.vocab_only = params.vocab_only;
 
         try {
-            llm_load_arch(ml, model);
+            model.load_arch(ml);
         } catch(const std::exception & e) {
             throw std::runtime_error("error loading model architecture: " + std::string(e.what()));
         }
         try {
-            llm_load_hparams(ml, model);
+            model.load_hparams(ml);
         } catch(const std::exception & e) {
             throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what()));
         }
         try {
-            llm_load_vocab(ml, model);
+            model.load_vocab(ml);
         } catch(const std::exception & e) {
             throw std::runtime_error("error loading model vocabulary: " + std::string(e.what()));
         }
 
-        llm_load_stats(ml, model);
-        llm_load_print_meta(ml, model);
-
-        if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE &&
-            model.hparams.n_vocab != model.vocab.id_to_token.size()) {
-            throw std::runtime_error("vocab size mismatch");
-        }
+        model.load_stats(ml);
+        model.print_info();
 
         if (params.vocab_only) {
             LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
             return 0;
         }
 
-        if (!llm_load_tensors(
-            ml, model, params.n_gpu_layers, params.split_mode,  params.main_gpu, params.tensor_split, params.use_mlock,
-            params.progress_callback, params.progress_callback_user_data
-        )) {
+        if (!model.load_tensors(ml)) {
             return -2;
         }
     } catch (const std::exception & err) {
@@ -2504,10 +78,6 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
         return -1;
     }
 
-    // loading time will be recalculate after the first eval, so
-    // we take page faults deferred by mmap() into consideration
-    model.t_load_us = ggml_time_us() - model.t_start_us;
-
     return 0;
 }
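The timing above follows an RAII pattern: measurement starts when the scoped object is constructed and the elapsed time is added to model.t_load_us when it is destroyed, so every exit path of llama_model_load (including the error returns) is accounted for. A minimal sketch of that pattern, assuming a ggml_time_us()-based helper similar to time_meas from llama-impl.h (the member names here are illustrative, not the actual declaration):

    // requires "ggml.h" for ggml_time_us()
    struct scoped_timer_us {
        // on destruction, add the elapsed time to the accumulator
        explicit scoped_timer_us(int64_t & acc) : t_start_us(ggml_time_us()), acc(acc) {}
        ~scoped_timer_us() { acc += ggml_time_us() - t_start_us; }

        const int64_t t_start_us;
        int64_t     & acc;
    };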
 
@@ -2553,6 +123,21 @@ static struct ggml_tensor * llm_build_inp_embd(
         ggml_set_input(lctx.inp_tokens);
 
         inpL = ggml_get_rows(ctx, tok_embd, lctx.inp_tokens);
+
+        // apply lora for embedding tokens if needed
+        for (auto & it : lctx.lora) {
+            struct llama_adapter_lora_weight * lw = it.first->get_weight(tok_embd);
+            if (lw == nullptr) {
+                continue;
+            }
+            const float adapter_scale = it.second;
+            const float scale = lw->get_scale(it.first->alpha, adapter_scale);
+            struct ggml_tensor * inpL_delta = ggml_scale(ctx, ggml_mul_mat(
+                ctx, lw->b, // non-transposed lora_b
+                ggml_get_rows(ctx, lw->a, lctx.inp_tokens)
+            ), scale);
+            inpL = ggml_add(ctx, inpL, inpL_delta);
+        }
     } else {
         lctx.inp_embd = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
         inpL = lctx.inp_embd;
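For the token-embedding path above, each active adapter contributes the standard LoRA delta, restricted to the embedding rows selected by inp_tokens: inpL += scale * B * (A[rows]). The scale follows the usual alpha/rank convention that llm_build_lora_mm_id further down still spells out inline, with the rank taken from lora_b's first dimension. A hedged sketch of what get_scale is assumed to compute:

    // assumed body of llama_adapter_lora_weight::get_scale (compare the inline
    // version kept in llm_build_lora_mm_id below)
    float get_scale(float alpha, float adapter_scale) const {
        const float rank = (float) b->ne[0];           // LoRA rank = lora_b->ne[0]
        return alpha ? adapter_scale * alpha / rank    // classic alpha/r scaling
                     : adapter_scale;                  // no alpha stored: raw user scale
    }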
@@ -2620,17 +205,16 @@ static struct ggml_tensor * llm_build_lora_mm(
           struct ggml_tensor * w,
           struct ggml_tensor * cur) {
     struct ggml_tensor * res = ggml_mul_mat(ctx0, w, cur);
-    for (auto & it : lctx.lora_adapters) {
-        struct llama_lora_weight * lora = it.first->get_weight(w);
-        if (lora == nullptr) {
+    for (auto & it : lctx.lora) {
+        struct llama_adapter_lora_weight * lw = it.first->get_weight(w);
+        if (lw == nullptr) {
             continue;
         }
-        const float alpha = it.first->alpha;
-        const float rank  = (float) lora->b->ne[0];
-        const float scale = alpha ? it.second * alpha / rank : it.second;
+        const float adapter_scale = it.second;
+        const float scale = lw->get_scale(it.first->alpha, adapter_scale);
         struct ggml_tensor * ab_cur = ggml_mul_mat(
-            ctx0, lora->b,
-            ggml_mul_mat(ctx0, lora->a, cur)
+            ctx0, lw->b,
+            ggml_mul_mat(ctx0, lw->a, cur)
         );
         ab_cur = ggml_scale(ctx0, ab_cur, scale);
         res = ggml_add(ctx0, res, ab_cur);
@@ -2646,17 +230,17 @@ static struct ggml_tensor * llm_build_lora_mm_id(
           struct ggml_tensor * cur, // struct ggml_tensor * b
           struct ggml_tensor * ids) {
     struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
-    for (auto & it : lctx.lora_adapters) {
-        struct llama_lora_weight * lora = it.first->get_weight(w);
-        if (lora == nullptr) {
+    for (auto & it : lctx.lora) {
+        struct llama_adapter_lora_weight * lw = it.first->get_weight(w);
+        if (lw == nullptr) {
             continue;
         }
         const float alpha = it.first->alpha;
-        const float rank  = (float) lora->b->ne[0];
+        const float rank  = (float) lw->b->ne[0];
         const float scale = alpha ? it.second * alpha / rank : it.second;
         struct ggml_tensor * ab_cur = ggml_mul_mat_id(
-            ctx0, lora->b,
-            ggml_mul_mat_id(ctx0, lora->a, cur, ids),
+            ctx0, lw->b,
+            ggml_mul_mat_id(ctx0, lw->a, cur, ids),
             ids
         );
         ab_cur = ggml_scale(ctx0, ab_cur, scale);
@@ -3287,16 +871,20 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
         const struct llama_layer * layer,
         struct ggml_tensor * cur,
         struct ggml_tensor * x_prev,
-        struct ggml_tensor ** wkv_state) {
+        struct ggml_tensor ** wkv_state,
+        size_t wkv_head_size,
+        size_t head_count_kv) {
     size_t n_embd       = cur->ne[0];
     size_t n_seq_tokens = cur->ne[1];
     size_t n_seqs       = cur->ne[2];
 
-    size_t head_size  = layer->time_mix_first->ne[0];
-    size_t head_count = layer->time_mix_first->ne[1];
+    size_t head_size  = wkv_head_size;
+    size_t head_count = n_embd / head_size;
 
     size_t n_tokens = n_seqs * n_seq_tokens;
 
+    bool is_qrwkv = layer->time_mix_first == nullptr;
+
     struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
 
     sx  = ggml_reshape_2d(ctx, sx,  n_embd, n_tokens);
@@ -3325,69 +913,64 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
         xxx
     );
 
-    struct ggml_tensor *mw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
-    struct ggml_tensor *mk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
-    struct ggml_tensor *mv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
-    struct ggml_tensor *mr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
-    struct ggml_tensor *mg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
-
-    struct ggml_tensor * xw = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mw, layer->time_mix_lerp_w),
-            sx
-        ),
-        cur
-    );
+    struct ggml_tensor *xw, *xk, *xv, *xr, *xg;
+    if (layer->time_mix_lerp_fused) {
+        // fusing these weights yields a small performance improvement
+        sx  = ggml_reshape_3d(ctx, sx,  n_embd, 1, n_tokens);
+        cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
+        xxx = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xxx, layer->time_mix_lerp_fused), sx), cur);
+        xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
+    } else {
+        // for backward compatibility
+        xw = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], 0);
+        xk = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
+        xv = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
+        xr = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
+        xg = ggml_view_2d(ctx, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
 
-    struct ggml_tensor * xk = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mk, layer->time_mix_lerp_k),
-            sx
-        ),
-        cur
-    );
+        xw = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xw, layer->time_mix_lerp_w), sx), cur);
+        xk = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xk, layer->time_mix_lerp_k), sx), cur);
+        xv = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xv, layer->time_mix_lerp_v), sx), cur);
+        xr = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xr, layer->time_mix_lerp_r), sx), cur);
+        xg = ggml_add(ctx, ggml_mul(ctx, ggml_add(ctx, xg, layer->time_mix_lerp_g), sx), cur);
+    }
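+    // either way xw/xk/xv/xr/xg end up holding the same data-dependent mixes:
+    // the fused path bundles the five per-channel lerp vectors into the single
+    // time_mix_lerp_fused tensor, so the mix is one broadcasted add/mul against
+    // sx and cur (reshaped to 3D above for broadcasting), while the else branch
+    // keeps working for models that still store time_mix_lerp_{w,k,v,r,g} as
+    // separate tensors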
 
-    struct ggml_tensor * xv = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mv, layer->time_mix_lerp_v),
-            sx
-        ),
-        cur
-    );
+    struct ggml_tensor * r = llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr);
+    struct ggml_tensor * k = llm_build_lora_mm(lctx, ctx, layer->time_mix_key,        xk);
+    struct ggml_tensor * v = llm_build_lora_mm(lctx, ctx, layer->time_mix_value,      xv);
+    if (layer->time_mix_receptance_b) {
+        r = ggml_add(ctx, r, layer->time_mix_receptance_b);
+    }
+    if (layer->time_mix_key_b) {
+        k = ggml_add(ctx, k, layer->time_mix_key_b);
+    }
+    if (layer->time_mix_value_b) {
+        v = ggml_add(ctx, v, layer->time_mix_value_b);
+    }
 
-    struct ggml_tensor * xr = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mr, layer->time_mix_lerp_r),
-            sx
-        ),
-        cur
-    );
+    struct ggml_tensor * g = llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg);
+    if (is_qrwkv) {
+        g = ggml_sigmoid(ctx, g);
+    } else {
+        g = ggml_silu(ctx, g);
+    }
 
-    struct ggml_tensor * xg = ggml_add(
-        ctx,
-        ggml_mul(
-            ctx,
-            ggml_add(ctx, mg, layer->time_mix_lerp_g),
-            sx
-        ),
-        cur
-    );
+    if (head_count_kv != head_count) {
+        GGML_ASSERT(head_count % head_count_kv == 0);
+        k = ggml_reshape_4d(ctx, k, head_size, 1, head_count_kv, n_tokens);
+        v = ggml_reshape_4d(ctx, v, head_size, 1, head_count_kv, n_tokens);
+        struct ggml_tensor * tmp = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_size, head_count / head_count_kv, head_count_kv, n_tokens);
+        k = ggml_repeat(ctx, k, tmp);
+        v = ggml_repeat(ctx, v, tmp);
+    }
 
-    struct ggml_tensor * r = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_receptance, xr), head_size, 1,         head_count, n_tokens);
-    struct ggml_tensor * k = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_key,        xk), 1,         head_size, head_count, n_tokens);
-    struct ggml_tensor * v = ggml_reshape_4d(ctx, llm_build_lora_mm(lctx, ctx, layer->time_mix_value,      xv), head_size, 1,         head_count, n_tokens);
-    struct ggml_tensor * g = ggml_silu(
-        ctx,
-        llm_build_lora_mm(lctx, ctx, layer->time_mix_gate, xg)
-    );
+    k = ggml_reshape_3d(ctx, k, head_size, head_count, n_tokens);
+    v = ggml_reshape_3d(ctx, v, head_size, head_count, n_tokens);
+    r = ggml_reshape_3d(ctx, r, head_size, head_count, n_tokens);
 
     struct ggml_tensor * w = ggml_mul_mat(
         ctx,
@@ -3398,25 +981,35 @@ static struct ggml_tensor * llm_build_rwkv6_time_mix(
         )
     );
 
-    w = ggml_add(ctx, w, ggml_reshape_1d(ctx, layer->time_mix_decay, n_embd));
+    w = ggml_add(ctx, w, layer->time_mix_decay);
     w = ggml_exp(ctx, ggml_neg(ctx, ggml_exp(ctx, w)));
-    w = ggml_reshape_4d(ctx, w, 1, head_size, head_count, n_tokens);
+    w = ggml_reshape_3d(ctx, w, head_size, head_count, n_tokens);
 
-    k = ggml_transpose(ctx, k);
-    v = ggml_transpose(ctx, v);
-    r = ggml_transpose(ctx, r);
+    if (is_qrwkv) {
+        // k = k * (1 - w)
+        k = ggml_sub(ctx, k, ggml_mul(ctx, k, w));
+    }
 
-    struct ggml_tensor * wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
+    struct ggml_tensor * wkv_output;
+    if (!layer->time_mix_first) {
+        wkv_output = ggml_gated_linear_attn(ctx, k, v, r, w, *wkv_state, pow(head_size, -0.5f));
+    } else {
+        wkv_output = ggml_rwkv_wkv6(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
+    }
     cur = ggml_view_1d(ctx, wkv_output, n_embd * n_tokens, 0);
     *wkv_state = ggml_view_1d(ctx, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));
 
-    // group norm with head_count groups
-    cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
-    cur = ggml_norm(ctx, cur, 64e-5f);
+    if (!is_qrwkv) {
+        // group norm with head_count groups
+        cur = ggml_reshape_3d(ctx, cur, n_embd / head_count, head_count, n_tokens);
+        cur = ggml_norm(ctx, cur, 64e-5f);
 
-    // Convert back to regular vectors.
-    cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
-    cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
+        // Convert back to regular vectors.
+        cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
+        cur = ggml_add(ctx, ggml_mul(ctx, cur, layer->time_mix_ln), layer->time_mix_ln_b);
+    } else {
+        cur = ggml_reshape_2d(ctx, cur, n_embd, n_tokens);
+    }
 
     cur = ggml_mul(ctx, cur, g);
     cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
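Both wkv branches consume the same decay: w is squashed into (0, 1) with exp(-exp(w_raw)) before entering the recurrence, and the only difference is the kernel, ggml_rwkv_wkv6 when a time_mix_first ("bonus") tensor exists versus ggml_gated_linear_attn scaled by head_size^-0.5 when it does not (the is_qrwkv case, which also pre-scales k by (1 - w)). A scalar sketch of the decay mapping, purely illustrative:

    #include <cmath>

    // exp(-exp(x)) maps an unbounded raw decay into (0, 1):
    //   x -> -inf  =>  exp(-0)   = 1  (state fully kept)
    //   x -> +inf  =>  exp(-inf) = 0  (state fully forgotten)
    static inline float rwkv6_decay(float w_raw) {
        return std::exp(-std::exp(w_raw));
    }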
@@ -3572,7 +1165,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_k_shift() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         GGML_ASSERT(kv_self.size == n_ctx);
 
@@ -3622,7 +1215,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         for (uint32_t i = 0; i < ids.size(); ++i) {
             const uint32_t id = ids[i];
@@ -3881,7 +1474,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_llama() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -3975,6 +1568,7 @@ struct llm_build_context {
 
             // feed-forward network
             if (model.layers[il].ffn_gate_inp == nullptr) {
+
                 cur = llm_build_norm(ctx0, ffn_inp, hparams,
                         model.layers[il].ffn_norm, NULL,
                         LLM_NORM_RMS, cb, il);
@@ -4046,7 +1640,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_deci() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -4207,7 +1801,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_baichuan() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -4219,7 +1813,7 @@ struct llm_build_context {
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = model.type == MODEL_7B ? build_inp_pos() : nullptr;
+        struct ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? build_inp_pos() : nullptr;
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
@@ -4244,7 +1838,7 @@ struct llm_build_context {
                 cb(Vcur, "Vcur", il);
 
                 switch (model.type) {
-                    case MODEL_7B:
+                    case LLM_TYPE_7B:
                         Qcur = ggml_rope_ext(
                             ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
                             n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -4256,7 +1850,7 @@ struct llm_build_context {
                             ext_factor, attn_factor, beta_fast, beta_slow
                         );
                         break;
-                    case MODEL_13B:
+                    case LLM_TYPE_13B:
                         Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens);
                         Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd/n_head, n_head, n_tokens);
                         break;
@@ -4322,7 +1916,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_xverse() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -4425,7 +2019,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_falcon() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -4545,7 +2139,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_grok() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -4704,7 +2298,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_dbrx() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -4832,7 +2426,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_starcoder() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -4936,7 +2530,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_refact() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -5030,7 +2624,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_bert() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -5224,7 +2818,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_bloom() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -5325,7 +2919,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_mpt() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -5615,7 +3209,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -5727,7 +3321,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -5839,7 +3433,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen2vl() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
         GGML_ASSERT(n_embd_head == hparams.n_rot);
@@ -5957,7 +3551,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_qwen2moe() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -6105,7 +3699,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_phi2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -6226,7 +3820,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_phi3() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -6259,7 +3853,7 @@ struct llm_build_context {
 
                 struct ggml_tensor* attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm,
-                    NULL,
+                    model.layers[il].attn_norm_b,
                     LLM_NORM_RMS, cb, il);
                 cb(attn_norm_output, "attn_norm", il);
 
@@ -6274,8 +3868,7 @@ struct llm_build_context {
                     Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
                     Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd)));
                     Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
-                }
-                else {
+                } else {
                     Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
                     Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
                     Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
@@ -6319,14 +3912,12 @@ struct llm_build_context {
             residual = cur;
 
             cur = llm_build_norm(ctx0, cur, hparams,
-                model.layers[il].ffn_norm, NULL,
+                model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
                 LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            // FF
-            // special-case: the up and gate tensors are merged into a single tensor
-            // TOOD: support into llm_build_ffn
-            {
+            // feed-forward network
+            if (model.layers[il].ffn_gate_inp == nullptr) {
                 cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL, NULL,
                         NULL,                      NULL, NULL,
@@ -6334,6 +3925,20 @@ struct llm_build_context {
                         NULL,
                         LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                cur = llm_build_moe_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        cb, il);
+                cb(cur, "ffn_moe_out", il);
             }
 
             cur = ggml_add(ctx0, residual, cur);
@@ -6346,11 +3951,16 @@ struct llm_build_context {
 
         cur = llm_build_norm(ctx0, inpL, hparams,
             model.output_norm,
-            NULL,
+            model.output_norm_b,
             LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
 
         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        if (model.output_b != nullptr) {
+            cb(cur, "result_output_no_bias", -1);
+            cur = ggml_add(ctx0, cur, model.output_b);
+        }
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -6464,7 +4074,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gpt2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -6569,7 +4179,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_codeshell() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -6680,7 +4290,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_orion() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -6798,7 +4408,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_internlm2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -6916,7 +4526,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_minicpm3() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         //TODO: if the model varies, these parameters need to be read from the model
         const int64_t n_embd_base = 256;
@@ -7125,7 +4735,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gemma() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head_k = hparams.n_embd_head_k;
 
@@ -7233,7 +4843,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gemma2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head_k = hparams.n_embd_head_k;
 
@@ -7283,9 +4893,9 @@ struct llm_build_context {
 
                 // ref: https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
                 switch (model.type) {
-                    case llm_type::MODEL_2B:
-                    case llm_type::MODEL_9B:  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
-                    case llm_type::MODEL_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
+                    case LLM_TYPE_2B:
+                    case LLM_TYPE_9B:  Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));   break;
+                    case LLM_TYPE_27B: Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head))); break;
                     default: GGML_ABORT("fatal error");
                 };
                 cb(Qcur, "Qcur_scaled", il);
@@ -7369,7 +4979,7 @@ struct llm_build_context {
 
 
     struct ggml_cgraph * build_starcoder2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7488,7 +5098,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_mamba() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
@@ -7543,7 +5153,7 @@ struct llm_build_context {
 
     struct ggml_cgraph * build_command_r() {
 
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7691,7 +5301,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_cohere2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7828,7 +5438,7 @@ struct llm_build_context {
     //   * removed bias
     //   * removed MoE
     struct ggml_cgraph * build_olmo() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -7952,7 +5562,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_olmo2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -8080,7 +5690,7 @@ struct llm_build_context {
     //   * removed bias
     //   * added q, k norm
     struct ggml_cgraph * build_olmoe() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -8206,7 +5816,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_openelm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -8331,7 +5941,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_gptneox() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -8473,7 +6083,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_arctic() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -8607,7 +6217,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_deepseek() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -8764,7 +6374,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_deepseek2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -8994,7 +6604,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_bitnet() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -9145,7 +6755,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_t5_enc() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -9277,7 +6887,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_t5_dec() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -9482,7 +7092,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_jais() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -9574,7 +7184,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_chatglm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
@@ -9688,7 +7298,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_nemotron() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -9809,7 +7419,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_exaone() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -9936,7 +7546,7 @@ struct llm_build_context {
     }
 
     ggml_cgraph * build_rwkv6() {
-        ggml_cgraph *gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // Token shift state dimensions should be 2 * n_emb
         GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2);
@@ -9981,7 +7591,7 @@ struct llm_build_context {
                 1
             );
 
-            cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states));
+            cur = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, n_embd / hparams.wkv_head_size));
             ggml_build_forward_expand(gf, cur);
             ggml_build_forward_expand(
                 gf,
@@ -10048,6 +7658,118 @@ struct llm_build_context {
         return gf;
     }
 
+    // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
+    ggml_cgraph * build_rwkv6qwen2() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
+
+        GGML_ASSERT(n_embd == hparams.n_embd_k_s());
+
+        const int64_t n_seqs = ubatch.n_seqs;
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t n_tokens = ubatch.n_tokens;
+        GGML_ASSERT(n_seqs != 0);
+        GGML_ASSERT(ubatch.equal_seqs);
+        GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+        struct ggml_tensor * state_copy = build_inp_s_copy();
+        struct ggml_tensor * state_mask = build_inp_s_mask();
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
+
+        for (int il = 0; il < n_layer; ++il) {
+            const llama_layer * layer = &model.layers[il];
+
+            // (ab)using the KV cache to store the states
+            struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0,
+                    gf, kv_self.k_l[il], state_copy, state_mask,
+                    hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs);
+            struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0,
+                    gf, kv_self.v_l[il], state_copy, state_mask,
+                    hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs);
+
+            cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
+            token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 1, n_seqs);
+
+            struct ggml_tensor * x_norm_att = llm_build_norm(ctx0, cur, hparams, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, cb, il);
+            struct ggml_tensor * x_prev = ggml_concat(
+                ctx0,
+                token_shift,
+                ggml_view_3d(ctx0, x_norm_att, n_embd, n_seq_tokens - 1, n_seqs, x_norm_att->nb[1], x_norm_att->nb[2], 0),
+                1
+            );
+
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    wkv_states,
+                    ggml_view_1d(
+                        ctx0,
+                        kv_self.v_l[il],
+                        hparams.n_embd_v_s() * n_seqs,
+                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
+                    )
+                )
+            );
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, llm_build_rwkv6_time_mix(lctx, ctx0, layer, x_norm_att, x_prev, &wkv_states, hparams.wkv_head_size, hparams.n_head_kv()));
+            ggml_build_forward_expand(gf, ffn_inp);
+            ggml_build_forward_expand(
+                gf,
+                ggml_cpy(
+                    ctx0,
+                    wkv_states,
+                    ggml_view_1d(
+                        ctx0,
+                        kv_self.v_l[il],
+                        hparams.n_embd_v_s() * n_seqs,
+                        hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
+                    )
+                )
+            );
+
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = llm_build_ffn(ctx0, lctx, cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = lctx.cvec.apply_to(ctx0, cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+        struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+
+        cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     // ref: https://github.com/facebookresearch/chameleon
     // based on the original build_llama() function, changes:
     //   * qk-norm
@@ -10055,7 +7777,7 @@ struct llm_build_context {
     //   * removed bias
     //   * removed MoE
     struct ggml_cgraph * build_chameleon() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -10227,7 +7949,7 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_wavtokenizer_dec() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false);
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
@@ -10436,12 +8158,12 @@ static struct ggml_cgraph * llama_build_graph(
 
         // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
         // FIXME: fix in ggml_backend_sched
-        const bool full_offload = lctx.model.n_gpu_layers > (int)lctx.model.hparams.n_layer;
+        const bool full_offload = lctx.model.params.n_gpu_layers > (int) lctx.model.hparams.n_layer;
         if (ubatch.n_tokens < 32 || full_offload) {
             if (il != -1 && strcmp(name, "norm") == 0) {
-                const auto & dev_layer = lctx.model.dev_layer.at(il);
+                const auto & dev_layer = lctx.model.dev_layer(il);
                 for (auto & backend : lctx.backends) {
-                    if (ggml_backend_get_device(backend.get()) == dev_layer.dev) {
+                    if (ggml_backend_get_device(backend.get()) == dev_layer) {
                         if (ggml_backend_supports_op(backend.get(), cur)) {
                             ggml_backend_sched_set_tensor_backend(lctx.sched.get(), cur, backend.get());
                         }
@@ -10529,6 +8251,7 @@ static struct ggml_cgraph * llama_build_graph(
                 result = llm.build_phi2();
             } break;
         case LLM_ARCH_PHI3:
+        case LLM_ARCH_PHIMOE:
             {
                 result = llm.build_phi3();
             } break;
@@ -10656,6 +8379,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_rwkv6();
             } break;
+        case LLM_ARCH_RWKV6QWEN2:
+            {
+                result = llm.build_rwkv6qwen2();
+            } break;
         case LLM_ARCH_CHAMELEON:
             {
                 result = llm.build_chameleon();
@@ -10735,6 +8462,7 @@ static int llama_decode_impl(
     const uint32_t n_tokens_all = batch.n_tokens;
 
     const auto & model   = lctx.model;
+    const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
@@ -10742,7 +8470,7 @@ static int llama_decode_impl(
 
     if (batch.token) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
@@ -10762,7 +8490,7 @@ static int llama_decode_impl(
     llama_kv_slot_restorer kv_slot_restorer(kv_self);
 
     const int64_t n_embd  = hparams.n_embd;
-    const int64_t n_vocab = hparams.n_vocab;
+    const int64_t n_vocab = vocab.n_tokens();
 
     uint32_t n_outputs = 0;
     uint32_t n_outputs_prev = 0;
@@ -11077,7 +8805,7 @@ static int llama_encode_impl(
 
     if (batch.token) {
         for (uint32_t i = 0; i < n_tokens; ++i) {
-            if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
@@ -11254,9 +8982,9 @@ static void llama_kv_cache_defrag_impl(struct llama_context & lctx) {
     // each move requires 6*n_layer tensors (see build_defrag)
     //   - source view, destination view, copy operation
     //   - x2 for keys and values
-    //const uint32_t max_moves = llama_model_max_nodes(model)/(6*n_layer);
+    //const uint32_t max_moves = model.max_nodes()/(6*n_layer);
     // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
-    const uint32_t max_moves = (llama_model_max_nodes(lctx.model) - 2*n_layer)/(6*n_layer);
+    const uint32_t max_moves = (lctx.model.max_nodes() - 2*n_layer)/(6*n_layer);
 
     // determine which KV cells to move where
     //
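For a sense of the budget this gives, here is a worked instance with illustrative values only, assuming max_nodes() returns 8192 and n_layer is 32 (neither value is taken from the patch):

    // illustrative only: max_nodes() == 8192, n_layer == 32
    // max_moves = (8192 - 2*32) / (6*32) = 8128 / 192 = 42 KV-cell moves per defrag graph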
@@ -11503,7 +9231,7 @@ static void llama_kv_cache_update_impl(struct llama_context & lctx) {
         // build worst-case graph
         uint32_t n_seqs = 1; // TODO: worst-case number of sequences
         uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
-        llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+        llama_token token = lctx.model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
         llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
         ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
 
@@ -11515,39 +9243,38 @@ static void llama_kv_cache_update_impl(struct llama_context & lctx) {
     }
 }
 
-int32_t llama_lora_adapter_set(
+int32_t llama_set_adapter_lora(
             struct llama_context * ctx,
-            struct llama_lora_adapter * adapter,
+            struct llama_adapter_lora * adapter,
             float scale) {
-    ctx->lora_adapters[adapter] = scale;
+    ctx->lora[adapter] = scale;
     return 0;
 }
 
-int32_t llama_lora_adapter_remove(
+int32_t llama_rm_adapter_lora(
             struct llama_context * ctx,
-            struct llama_lora_adapter * adapter) {
-    auto pos = ctx->lora_adapters.find(adapter);
-    if (pos != ctx->lora_adapters.end()) {
-        ctx->lora_adapters.erase(pos);
+            struct llama_adapter_lora * adapter) {
+    auto pos = ctx->lora.find(adapter);
+    if (pos != ctx->lora.end()) {
+        ctx->lora.erase(pos);
         return 0;
     }
 
     return -1;
 }
 
-void llama_lora_adapter_clear(struct llama_context * ctx) {
-    ctx->lora_adapters.clear();
+void llama_clear_adapter_lora(struct llama_context * ctx) {
+    ctx->lora.clear();
 }
 
-// TODO: tmp
-int32_t llama_control_vector_apply(
-        struct llama_context * lctx,
+int32_t llama_apply_adapter_cvec(
+        struct llama_context * ctx,
                  const float * data,
                       size_t   len,
                      int32_t   n_embd,
                      int32_t   il_start,
                      int32_t   il_end) {
-    return llama_control_vector_apply(lctx->cvec, lctx->model, data, len, n_embd, il_start, il_end);
+    return ctx->cvec.apply(ctx->model, data, len, n_embd, il_start, il_end);
 }
 
 //
@@ -11658,7 +9385,7 @@ struct llama_model * llama_model_load_from_file(
         struct llama_model_params params) {
     ggml_time_init();
 
-    llama_model * model = new llama_model;
+    llama_model * model = new llama_model(params);
 
     unsigned cur_percentage = 0;
     if (params.progress_callback == NULL) {
@@ -11758,7 +9485,7 @@ struct llama_model * llama_model_load_from_file(
         LLAMA_LOG_INFO("%s: using device %s (%s) - %zu MiB free\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), free/1024/1024);
     }
 
-    int status = llama_model_load(path_model, *model, params);
+    const int status = llama_model_load(path_model, *model, params);
     GGML_ASSERT(status <= 0);
     if (status < 0) {
         if (status == -1) {
@@ -11774,7 +9501,7 @@ struct llama_model * llama_model_load_from_file(
     return model;
 }
 
-struct llama_context * llama_new_context_with_model(
+struct llama_context * llama_init_from_model(
                  struct llama_model * model,
         struct llama_context_params   params) {
 
@@ -12032,7 +9759,7 @@ struct llama_context * llama_new_context_with_model(
                 backend_ptrs.push_back(backend.get());
             }
 
-            const size_t max_nodes = llama_model_max_nodes(*model);
+            const size_t max_nodes = model->max_nodes();
 
             // buffer used to store the computation graph and the tensor meta data
             ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
@@ -12040,9 +9767,9 @@ struct llama_context * llama_new_context_with_model(
             // TODO: move these checks to ggml_backend_sched
             // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
             bool pipeline_parallel =
-                llama_get_device_count(*model) > 1 &&
-                model->n_gpu_layers > (int)model->hparams.n_layer &&
-                model->split_mode == LLAMA_SPLIT_MODE_LAYER &&
+                model->n_devices() > 1 &&
+                model->params.n_gpu_layers > (int)model->hparams.n_layer &&
+                model->params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
                 params.offload_kqv;
 
             // pipeline parallelism requires support for async compute and events in all devices
@@ -12073,7 +9800,7 @@ struct llama_context * llama_new_context_with_model(
             // initialize scheduler with the worst-case graph
             uint32_t n_seqs = 1; // TODO: worst-case number of sequences
             uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-            llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
+            llama_token token = ctx->model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
 
             llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
             ggml_cgraph * gf_pp = llama_build_graph(*ctx, ubatch_pp, true);
@@ -12125,6 +9852,12 @@ struct llama_context * llama_new_context_with_model(
     return ctx;
 }
 
+struct llama_context * llama_new_context_with_model(
+                 struct llama_model * model,
+        struct llama_context_params   params) {
+    return llama_init_from_model(model, params);
+}
+
 //
 // kv cache
 //
@@ -12222,166 +9955,18 @@ int32_t llama_decode(
     return ret;
 }
 
-//
-// vocab
-//
-
-// TODO: tmp bridges below until `struct llama_vocab` is exposed through the public API
-
-const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
-    return llama_token_get_text_impl(model->vocab, token);
-}
-
-float llama_token_get_score(const struct llama_model * model, llama_token token) {
-    return llama_token_get_score_impl(model->vocab, token);
-}
-
-enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
-    return llama_token_get_attr_impl(model->vocab, token);
-}
-
-bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
-    return llama_token_is_eog_impl(model->vocab, token);
-}
-
-bool llama_token_is_control(const struct llama_model * model, llama_token token) {
-    return llama_token_is_control_impl(model->vocab, token);
-}
-
-llama_token llama_token_bos(const struct llama_model * model) {
-    return llama_token_bos_impl(model->vocab);
-}
-
-llama_token llama_token_eos(const struct llama_model * model) {
-    return llama_token_eos_impl(model->vocab);
-}
-
-llama_token llama_token_eot(const struct llama_model * model) {
-    return llama_token_eot_impl(model->vocab);
-}
-
-llama_token llama_token_cls(const struct llama_model * model) {
-    return llama_token_cls_impl(model->vocab);
-}
-
-llama_token llama_token_sep(const struct llama_model * model) {
-    return llama_token_sep_impl(model->vocab);
-}
-
-llama_token llama_token_nl (const struct llama_model * model) {
-    return llama_token_nl_impl(model->vocab);
-}
-
-llama_token llama_token_pad(const struct llama_model * model) {
-    return llama_token_pad_impl(model->vocab);
-}
-
-bool llama_add_bos_token(const struct llama_model * model) {
-    return llama_add_bos_token_impl(model->vocab);
-}
-
-bool llama_add_eos_token(const struct llama_model * model) {
-    return llama_add_eos_token_impl(model->vocab);
-}
-
-llama_token llama_token_prefix(const struct llama_model * model) {
-    return llama_token_prefix_impl(model->vocab);
-}
-
-llama_token llama_token_middle(const struct llama_model * model) {
-    return llama_token_middle_impl(model->vocab);
-}
-
-llama_token llama_token_suffix(const struct llama_model * model) {
-    return llama_token_suffix_impl(model->vocab);
-}
-
-llama_token llama_token_fim_pre(const struct llama_model * model) {
-    return llama_token_fim_pre_impl(model->vocab);
-}
-
-llama_token llama_token_fim_suf(const struct llama_model * model) {
-    return llama_token_fim_suf_impl(model->vocab);
-}
-
-llama_token llama_token_fim_mid(const struct llama_model * model) {
-    return llama_token_fim_mid_impl(model->vocab);
-}
-
-llama_token llama_token_fim_pad(const struct llama_model * model) {
-    return llama_token_fim_pad_impl(model->vocab);
-}
-
-llama_token llama_token_fim_rep(const struct llama_model * model) {
-    return llama_token_fim_rep_impl(model->vocab);
-}
-
-llama_token llama_token_fim_sep(const struct llama_model * model) {
-    return llama_token_fim_sep_impl(model->vocab);
-}
-
-//
-// tokenization
-//
-
-int32_t llama_tokenize(
-    const struct llama_model * model,
-                  const char * text,
-                     int32_t   text_len,
-                 llama_token * tokens,
-                     int32_t   n_tokens_max,
-                        bool   add_special,
-                        bool   parse_special) {
-    return llama_tokenize_impl(model->vocab, text, text_len, tokens, n_tokens_max, add_special, parse_special);
-}
-
-int32_t llama_token_to_piece(
-    const struct llama_model * model,
-                 llama_token   token,
-                        char * buf,
-                     int32_t   length,
-                     int32_t   lstrip,
-                        bool   special) {
-    return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
-}
-
-int32_t llama_detokenize(
-    const struct llama_model * model,
-           const llama_token * tokens,
-                     int32_t   n_tokens,
-                        char * text,
-                     int32_t   text_len_max,
-                        bool   remove_special,
-                        bool   unparse_special) {
-    return llama_detokenize_impl(model->vocab, tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
-}
-
 //
 // chat templates
 //
 
 int32_t llama_chat_apply_template(
-                const struct llama_model * model,
                               const char * tmpl,
          const struct llama_chat_message * chat,
                                   size_t   n_msg,
                                     bool   add_ass,
                                     char * buf,
                                  int32_t   length) {
-    std::string curr_tmpl(tmpl == nullptr ? "" : tmpl);
-    if (tmpl == nullptr) {
-        GGML_ASSERT(model != nullptr);
-
-        // load template from model, if available
-        const auto & it = model->gguf_kv.find("tokenizer.chat_template");
-        if (it != model->gguf_kv.end() && it->second.size() > 0) {
-            curr_tmpl = it->second;
-        }
-        else {
-            // worst case: there is no information about template, we will use chatml by default
-            curr_tmpl = "chatml";  // see llm_chat_apply_template
-        }
-    }
+    const std::string curr_tmpl(tmpl == nullptr ? "chatml" : tmpl);
 
     // format the chat to string
     std::vector<const llama_chat_message *> chat_vec;
@@ -12405,23 +9990,6 @@ int32_t llama_chat_apply_template(
     return res;
 }
 
-//
-// sampling
-//
-
-// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
-struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
-    return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
-}
-
-struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model) {
-    return llama_sampler_init_infill_impl(model->vocab);
-}
-
-struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
-    return llama_sampler_init_dry_impl(model->vocab, llama_n_ctx_train(model), dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers, num_breakers);
-}
-
 //
 // model split
 //
@@ -12434,16 +10002,16 @@ int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix,
     return 0;
 }
 
-int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) {
+int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count) {
     std::string str_split_path(split_path);
     char postfix[32];
     snprintf(postfix, 32, "-%05d-of-%05d.gguf", split_no + 1, split_count);
     std::string str_postfix(postfix);
 
-    // check if dest ends with postfix
+    // check if split_prefix ends with postfix
     int size_prefix = str_split_path.size() - str_postfix.size();
     if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) {
-        snprintf(dest, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
+        snprintf(split_prefix, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
         return size_prefix;
     }
 
diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h
index 0295a51fbee..a184884c77a 100644
--- a/examples/talk-llama/llama.h
+++ b/examples/talk-llama/llama.h
@@ -56,7 +56,7 @@ extern "C" {
     // TODO: show sample usage
     //
 
-    // struct llama_vocab; // TODO: add in the future
+    struct llama_vocab;
     struct llama_model;
     struct llama_context;
     struct llama_sampler;
@@ -385,8 +385,7 @@ extern "C" {
     } llama_chat_message;
 
     // lora adapter
-    // TODO: rename to llama_adapter_lora
-    struct llama_lora_adapter;
+    struct llama_adapter_lora;
 
     // Helpers for getting default parameters
     // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172)
@@ -400,18 +399,19 @@ extern "C" {
     // Call once at the start of the program
     LLAMA_API void llama_backend_init(void);
 
+    // Call once at the end of the program - currently only used for MPI
+    LLAMA_API void llama_backend_free(void);
+
     //optional:
     LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa);
 
     // Optional: an auto threadpool gets created in ggml if not passed explicitly
     LLAMA_API void llama_attach_threadpool(
-               struct   llama_context * ctx,
-            ggml_threadpool_t   threadpool,
-            ggml_threadpool_t   threadpool_batch);
-    LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
+            struct llama_context * ctx,
+               ggml_threadpool_t   threadpool,
+               ggml_threadpool_t   threadpool_batch);
 
-    // Call once at the end of the program - currently only used for MPI
-    LLAMA_API void llama_backend_free(void);
+    LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
 
     DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file(
                              const char * path_model,
@@ -427,11 +427,15 @@ extern "C" {
 
     LLAMA_API void llama_model_free(struct llama_model * model);
 
-    // TODO: rename to llama_init_from_model
-    LLAMA_API struct llama_context * llama_new_context_with_model(
+    LLAMA_API struct llama_context * llama_init_from_model(
                      struct llama_model * model,
             struct llama_context_params   params);
 
+    DEPRECATED(LLAMA_API struct llama_context * llama_new_context_with_model(
+                     struct llama_model * model,
+            struct llama_context_params   params),
+            "use llama_init_from_model instead");
+
     // Frees all allocated memory
     LLAMA_API void llama_free(struct llama_context * ctx);
 
@@ -449,20 +453,30 @@ extern "C" {
     LLAMA_API uint32_t llama_n_ubatch   (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max  (const struct llama_context * ctx);
 
-    LLAMA_API int32_t llama_n_vocab    (const struct llama_model * model);
-    LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
-    LLAMA_API int32_t llama_n_embd     (const struct llama_model * model);
-    LLAMA_API int32_t llama_n_layer    (const struct llama_model * model);
-    LLAMA_API int32_t llama_n_head     (const struct llama_model * model);
+    DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead");
+    DEPRECATED(LLAMA_API int32_t llama_n_embd     (const struct llama_model * model), "use llama_model_n_embd instead");
+    DEPRECATED(LLAMA_API int32_t llama_n_layer    (const struct llama_model * model), "use llama_model_n_layer instead");
+    DEPRECATED(LLAMA_API int32_t llama_n_head     (const struct llama_model * model), "use llama_model_n_head instead");
+
+    DEPRECATED(LLAMA_API int32_t llama_n_vocab    (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
+
+    LLAMA_API const struct llama_model * llama_get_model   (const struct llama_context * ctx);
+    LLAMA_API enum llama_pooling_type    llama_pooling_type(const struct llama_context * ctx);
 
-    LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
+    LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
+    LLAMA_API enum llama_rope_type       llama_model_rope_type(const struct llama_model * model);
 
-    LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
-    LLAMA_API enum llama_vocab_type   llama_vocab_type  (const struct llama_model * model);
-    LLAMA_API enum llama_rope_type    llama_rope_type   (const struct llama_model * model);
+    LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
+    LLAMA_API int32_t llama_model_n_embd     (const struct llama_model * model);
+    LLAMA_API int32_t llama_model_n_layer    (const struct llama_model * model);
+    LLAMA_API int32_t llama_model_n_head     (const struct llama_model * model);
 
     // Get the model's RoPE frequency scaling factor
-    LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
+    LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
+
+    LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);
+
+    LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);
 
     // Functions to access the model's GGUF metadata scalar values
     // - The functions return the length of the string on success, or -1 on failure
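As a minimal sketch of the renamed entry points above (illustrative only; assumes llama.h is included and `model` was obtained from llama_model_load_from_file):

    llama_context * ctx = llama_init_from_model(model, llama_context_default_params()); // was llama_new_context_with_model

    const llama_vocab * vocab = llama_model_get_vocab(model);

    const int32_t n_vocab     = llama_vocab_n_tokens(vocab);    // was llama_n_vocab(model)
    const int32_t n_embd      = llama_model_n_embd(model);      // was llama_n_embd(model)
    const int32_t n_ctx_train = llama_model_n_ctx_train(model); // was llama_n_ctx_train(model)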
@@ -488,6 +502,9 @@ extern "C" {
     // Returns the total size of all the tensors in the model in bytes
     LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
 
+    // Get the default chat template. Returns nullptr if not available
+    LLAMA_API const char * llama_model_chat_template(const struct llama_model * model);
+
     // Returns the total number of parameters in the model
     LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
 
@@ -515,34 +532,31 @@ extern "C" {
     //
 
     // Load a LoRA adapter from file
-    // TODO: rename to llama_adapter_lora_init
-    LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init(
+    LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init(
             struct llama_model * model,
             const char * path_lora);
 
+    // Manually free a LoRA adapter
+    // Note: loaded adapters will be freed when the associated model is deleted
+    LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
+
+    // The following functions operate on a llama_context, hence the naming: llama_verb_...
+
     // Add a loaded LoRA adapter to given context
     // This will not modify model's weight
-    // TODO: rename to llama_set_adapter_lora
-    LLAMA_API int32_t llama_lora_adapter_set(
+    LLAMA_API int32_t llama_set_adapter_lora(
             struct llama_context * ctx,
-            struct llama_lora_adapter * adapter,
+            struct llama_adapter_lora * adapter,
             float scale);
 
     // Remove a specific LoRA adapter from given context
     // Return -1 if the adapter is not present in the context
-    // TODO: rename to llama_rm_adapter_lora
-    LLAMA_API int32_t llama_lora_adapter_remove(
+    LLAMA_API int32_t llama_rm_adapter_lora(
             struct llama_context * ctx,
-            struct llama_lora_adapter * adapter);
+            struct llama_adapter_lora * adapter);
 
     // Remove all LoRA adapters from given context
-    // TODO: rename to llama_clear_adapter_lora
-    LLAMA_API void llama_lora_adapter_clear(struct llama_context * ctx);
-
-    // Manually free a LoRA adapter
-    // Note: loaded adapters will be free when the associated model is deleted
-    // TODO: rename to llama_adapter_lora_free
-    LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter);
+    LLAMA_API void llama_clear_adapter_lora(struct llama_context * ctx);
 
     // Apply a loaded control vector to a llama_context, or if data is NULL, clear
     // the currently loaded vector.
@@ -550,9 +564,8 @@ extern "C" {
     // to an n_embd x n_layers buffer starting from layer 1.
     // il_start and il_end are the layer range the vector should apply to (both inclusive)
     // See llama_control_vector_load in common to load a control vector.
-    // TODO: rename to llama_adapter_cvec_apply
-    LLAMA_API int32_t llama_control_vector_apply(
-            struct llama_context * lctx,
+    LLAMA_API int32_t llama_apply_adapter_cvec(
+            struct llama_context * ctx,
                      const float * data,
                           size_t   len,
                          int32_t   n_embd,
@@ -908,41 +921,60 @@ extern "C" {
     // Vocab
     //
 
-    LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token);
+    LLAMA_API const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token);
 
-    LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
+    LLAMA_API float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token);
 
-    LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token);
+    LLAMA_API enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token);
 
     // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
-    LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);
+    LLAMA_API bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token);
 
     // Identify if Token Id is a control token or a render-able token
-    LLAMA_API bool llama_token_is_control(const struct llama_model * model, llama_token token);
+    LLAMA_API bool llama_vocab_is_control(const struct llama_vocab * vocab, llama_token token);
 
     // Special tokens
-    LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
-    LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
-    LLAMA_API llama_token llama_token_eot(const struct llama_model * model); // end-of-turn
-    LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
-    LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
-    LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
-    LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding
-
-    LLAMA_API bool llama_add_bos_token(const struct llama_model * model);
-    LLAMA_API bool llama_add_eos_token(const struct llama_model * model);
-
-    // infill tokens
-    DEPRECATED(LLAMA_API llama_token llama_token_prefix(const struct llama_model * model), "use llama_token_fim_pre instead");
-    DEPRECATED(LLAMA_API llama_token llama_token_middle(const struct llama_model * model), "use llama_token_fim_mid instead");
-    DEPRECATED(LLAMA_API llama_token llama_token_suffix(const struct llama_model * model), "use llama_token_fim_suf instead");
-
-    LLAMA_API llama_token llama_token_fim_pre(const struct llama_model * model);
-    LLAMA_API llama_token llama_token_fim_suf(const struct llama_model * model);
-    LLAMA_API llama_token llama_token_fim_mid(const struct llama_model * model);
-    LLAMA_API llama_token llama_token_fim_pad(const struct llama_model * model);
-    LLAMA_API llama_token llama_token_fim_rep(const struct llama_model * model);
-    LLAMA_API llama_token llama_token_fim_sep(const struct llama_model * model);
+    LLAMA_API llama_token llama_vocab_bos(const struct llama_vocab * vocab); // beginning-of-sentence
+    LLAMA_API llama_token llama_vocab_eos(const struct llama_vocab * vocab); // end-of-sentence
+    LLAMA_API llama_token llama_vocab_eot(const struct llama_vocab * vocab); // end-of-turn
+    LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator
+    LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
+    LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
+
+    LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
+    LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
+
+    LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab);
+    LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab);
+    LLAMA_API llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab);
+    LLAMA_API llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab);
+    LLAMA_API llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab);
+    LLAMA_API llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab);
+
+    DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_text instead");
+    DEPRECATED(LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_score instead");
+    DEPRECATED(LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_attr instead");
+    DEPRECATED(LLAMA_API bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_eog instead");
+    DEPRECATED(LLAMA_API bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_control instead");
+    DEPRECATED(LLAMA_API llama_token llama_token_bos(const struct llama_vocab * vocab), "use llama_vocab_bos instead");
+    DEPRECATED(LLAMA_API llama_token llama_token_eos(const struct llama_vocab * vocab), "use llama_vocab_eos instead");
+    DEPRECATED(LLAMA_API llama_token llama_token_eot(const struct llama_vocab * vocab), "use llama_vocab_eot instead");
+    DEPRECATED(LLAMA_API llama_token llama_token_cls(const struct llama_vocab * vocab), "use llama_vocab_cls instead");
+    DEPRECATED(LLAMA_API llama_token llama_token_sep(const struct llama_vocab * vocab), "use llama_vocab_sep instead");
+    DEPRECATED(LLAMA_API llama_token llama_token_nl (const struct llama_vocab * vocab), "use llama_vocab_nl instead");
+    DEPRECATED(LLAMA_API llama_token llama_token_pad(const struct llama_vocab * vocab), "use llama_vocab_pad instead");
+    DEPRECATED(LLAMA_API bool llama_add_bos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_bos instead");
+    DEPRECATED(LLAMA_API bool llama_add_eos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_eos instead");
+    DEPRECATED(LLAMA_API llama_token llama_token_fim_pre(const struct llama_vocab * vocab), "use llama_vocab_fim_pre instead");
+    DEPRECATED(LLAMA_API llama_token llama_token_fim_suf(const struct llama_vocab * vocab), "use llama_vocab_fim_suf instead");
+    DEPRECATED(LLAMA_API llama_token llama_token_fim_mid(const struct llama_vocab * vocab), "use llama_vocab_fim_mid instead");
+    DEPRECATED(LLAMA_API llama_token llama_token_fim_pad(const struct llama_vocab * vocab), "use llama_vocab_fim_pad instead");
+    DEPRECATED(LLAMA_API llama_token llama_token_fim_rep(const struct llama_vocab * vocab), "use llama_vocab_fim_rep instead");
+    DEPRECATED(LLAMA_API llama_token llama_token_fim_sep(const struct llama_vocab * vocab), "use llama_vocab_fim_sep instead");
+
+    // CLS is equivalent to BOS
+    DEPRECATED(LLAMA_API llama_token llama_vocab_cls(const struct llama_vocab * vocab), // classification
+            "use llama_vocab_bos instead");
 
     //
     // Tokenization
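A short sketch of the vocab-based replacements for the deprecated token queries (illustrative only; `model` is assumed loaded and `new_token` stands in for the most recently sampled token):

    const llama_vocab * vocab = llama_model_get_vocab(model);

    const llama_token bos = llama_vocab_bos(vocab);         // was llama_token_bos(model)
    const bool add_bos    = llama_vocab_get_add_bos(vocab); // was llama_add_bos_token(model)

    // end the generation loop on any end-of-generation token (EOS, EOT, ...)
    if (llama_vocab_is_eog(vocab, new_token)) {             // was llama_token_is_eog(model, token)
        // stop sampling
    }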
@@ -958,7 +990,7 @@ extern "C" {
     /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
     ///                      as plaintext. Does not insert a leading space.
     LLAMA_API int32_t llama_tokenize(
-        const struct llama_model * model,
+        const struct llama_vocab * vocab,
                       const char * text,
                          int32_t   text_len,
                      llama_token * tokens,
@@ -972,7 +1004,7 @@ extern "C" {
     // User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix')
     // @param special If true, special tokens are rendered in the output.
     LLAMA_API int32_t llama_token_to_piece(
-              const struct llama_model * model,
+              const struct llama_vocab * vocab,
                            llama_token   token,
                                   char * buf,
                                int32_t   length,
@@ -986,7 +1018,7 @@ extern "C" {
     /// @param remove_special Allow to remove BOS and EOS tokens if model is configured to do so.
     /// @param unparse_special If true, special tokens are rendered in the output.
     LLAMA_API int32_t llama_detokenize(
-        const struct llama_model * model,
+        const struct llama_vocab * vocab,
                const llama_token * tokens,
                          int32_t   n_tokens,
                             char * text,
@@ -1009,7 +1041,6 @@ extern "C" {
     /// @param length The size of the allocated buffer
     /// @return The total number of bytes of the formatted prompt. If it is larger than the size of the buffer, you may need to re-alloc it and then re-apply the template.
     LLAMA_API int32_t llama_chat_apply_template(
-              const struct llama_model * model,
                             const char * tmpl,
        const struct llama_chat_message * chat,
                                 size_t   n_msg,
@@ -1057,7 +1088,6 @@ extern "C" {
     //    llama_sampler_free(smpl);
     //
     // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
-    // TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab
     //
 
     typedef void * llama_sampler_context_t;
@@ -1157,7 +1187,7 @@ extern "C" {
                                float   eta);
 
     LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
-            const struct llama_model * model,
+            const struct llama_vocab * vocab,
                           const char * grammar_str,
                           const char * grammar_root);
 
@@ -1169,8 +1199,9 @@ extern "C" {
                                float   penalty_present); // 0.0 = disabled
 
     ///  @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
-    LLAMA_API struct llama_sampler *    llama_sampler_init_dry(
-            const struct llama_model *  model,
+    LLAMA_API struct llama_sampler * llama_sampler_init_dry(
+            const struct llama_vocab *  vocab,
+                             int32_t    n_ctx_train,
                                float    dry_multiplier,
                                float    dry_base,
                              int32_t    dry_allowed_length,
@@ -1204,7 +1235,7 @@ extern "C" {
     // 3. discard non-EOG tokens with low prob
     // 4. if no tokens are left -> pick EOT
     //
-    LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);
+    LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab);
 
     // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
     LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
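The llama.h hunks above move tokenization, token rendering and the vocab-dependent samplers from llama_model to llama_vocab. A minimal caller-side sketch of the migrated pattern, assuming a loaded llama_model * model and a std::string text; the grammar string, buffer sizes and sampler choices are illustrative only and not taken from this patch:

    const llama_vocab * vocab = llama_model_get_vocab(model);

    // a negative return value is the negated number of tokens required, so resize and retry
    std::vector<llama_token> tokens(text.size() + 1);
    int n = llama_tokenize(vocab, text.data(), text.size(), tokens.data(), tokens.size(),
                           /*add_special =*/ true, /*parse_special =*/ false);
    if (n < 0) {
        tokens.resize(-n);
        n = llama_tokenize(vocab, text.data(), text.size(), tokens.data(), tokens.size(), true, false);
    }
    tokens.resize(n);

    // samplers that previously took a llama_model now take the vocab directly
    // (grammar_str is a caller-supplied GBNF string, shown here only for illustration)
    llama_sampler * grammar = llama_sampler_init_grammar(vocab, grammar_str, "root");
    llama_sampler * infill  = llama_sampler_init_infill(vocab);

    // generation loops compare against the vocab as well
    // if (id == llama_vocab_eos(vocab)) { /* stop */ }

The resize-and-retry step mirrors the helper in talk-llama.cpp below and relies on the documented behaviour that a negative return value is the required token count.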
diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp
index e97ffae893b..dcdaec487cb 100644
--- a/examples/talk-llama/talk-llama.cpp
+++ b/examples/talk-llama/talk-llama.cpp
@@ -17,15 +17,16 @@
 #include 
 
 static std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
-    auto * model = llama_get_model(ctx);
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
 
     // upper limit for the number of tokens
     int n_tokens = text.length() + add_bos;
     std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, false);
+    n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_bos, false);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, false);
+        int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_bos, false);
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
@@ -34,11 +35,14 @@ static std::vector llama_tokenize(struct llama_context * ctx, const
 }
 
 static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
     std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), 0, false);
+    const int n_tokens = llama_token_to_piece(vocab, token, result.data(), result.size(), 0, false);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), 0, false);
+        int check = llama_token_to_piece(vocab, token, result.data(), result.size(), 0, false);
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
@@ -310,6 +314,8 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    const llama_vocab * vocab_llama = llama_model_get_vocab(model_llama);
+
     llama_context_params lcparams = llama_context_default_params();
 
     // tune these to your liking
@@ -317,7 +323,7 @@ int main(int argc, char ** argv) {
     lcparams.n_threads  = params.n_threads;
     lcparams.flash_attn = params.flash_attn;
 
-    struct llama_context * ctx_llama = llama_new_context_with_model(model_llama, lcparams);
+    struct llama_context * ctx_llama = llama_init_from_model(model_llama, lcparams);
 
     // print some info about the processing
     {
@@ -727,7 +733,7 @@ int main(int argc, char ** argv) {
 
                         const llama_token id = llama_sampler_sample(smpl, ctx_llama, -1);
 
-                        if (id != llama_token_eos(model_llama)) {
+                        if (id != llama_vocab_eos(vocab_llama)) {
                             // add it to the context
                             embd.push_back(id);
 
diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index 3935065332e..fe8acc8038b 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -243,7 +243,8 @@ set(GGML_PUBLIC_HEADERS
     include/ggml-metal.h
     include/ggml-rpc.h
     include/ggml-sycl.h
-    include/ggml-vulkan.h)
+    include/ggml-vulkan.h
+    include/gguf.h)
 
 set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
 #if (GGML_METAL)
diff --git a/ggml/include/ggml-cpp.h b/ggml/include/ggml-cpp.h
index 219361af43e..a12342c25de 100644
--- a/ggml/include/ggml-cpp.h
+++ b/ggml/include/ggml-cpp.h
@@ -7,6 +7,7 @@
 #include "ggml.h"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
+#include "gguf.h"
 #include <memory>
 
 // Smart pointers for ggml types
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index c714fc8c837..8f8cb9e1aa1 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -241,12 +241,6 @@
 #define GGML_ROPE_TYPE_MROPE  8
 #define GGML_ROPE_TYPE_VISION 24
 
-#define GGUF_MAGIC "GGUF"
-
-#define GGUF_VERSION 3
-
-#define GGUF_DEFAULT_ALIGNMENT 32
-
 #define GGML_UNUSED(x) (void)(x)
 
 #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
@@ -403,12 +397,6 @@ extern "C" {
         GGML_PREC_F32,
     };
 
-    enum ggml_backend_type {
-        GGML_BACKEND_TYPE_CPU = 0,
-        GGML_BACKEND_TYPE_GPU = 10,
-        GGML_BACKEND_TYPE_GPU_SPLIT = 20,
-    };
-
     // model file types
     enum ggml_ftype {
         GGML_FTYPE_UNKNOWN        = -1,
@@ -513,6 +501,7 @@ extern "C" {
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV6,
+        GGML_OP_GATED_LINEAR_ATTN,
 
         GGML_OP_UNARY,
 
@@ -587,8 +576,6 @@ extern "C" {
     struct ggml_tensor {
         enum ggml_type type;
 
-        GGML_DEPRECATED(enum ggml_backend_type backend, "use the buffer type to find the storage location of the tensor");
-
         struct ggml_backend_buffer * buffer;
 
         int64_t ne[GGML_MAX_DIMS]; // number of elements
@@ -1873,6 +1860,15 @@ extern "C" {
             struct ggml_tensor  * td,
             struct ggml_tensor  * state);
 
+    GGML_API struct ggml_tensor * ggml_gated_linear_attn(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * g,
+            struct ggml_tensor  * state,
+            float scale);
+
     // custom operators
 
     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
@@ -2111,132 +2107,6 @@ extern "C" {
                    int64_t   n_per_row,
                const float * imatrix);
 
-    //
-    // gguf
-    //
-
-    enum gguf_type {
-        GGUF_TYPE_UINT8   = 0,
-        GGUF_TYPE_INT8    = 1,
-        GGUF_TYPE_UINT16  = 2,
-        GGUF_TYPE_INT16   = 3,
-        GGUF_TYPE_UINT32  = 4,
-        GGUF_TYPE_INT32   = 5,
-        GGUF_TYPE_FLOAT32 = 6,
-        GGUF_TYPE_BOOL    = 7,
-        GGUF_TYPE_STRING  = 8,
-        GGUF_TYPE_ARRAY   = 9,
-        GGUF_TYPE_UINT64  = 10,
-        GGUF_TYPE_INT64   = 11,
-        GGUF_TYPE_FLOAT64 = 12,
-        GGUF_TYPE_COUNT,       // marks the end of the enum
-    };
-
-    struct gguf_context;
-
-    struct gguf_init_params {
-        bool no_alloc;
-
-        // if not NULL, create a ggml_context and allocate the tensor data in it
-        struct ggml_context ** ctx;
-    };
-
-    GGML_API struct gguf_context * gguf_init_empty(void);
-    GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
-    //GGML_API struct gguf_context * gguf_init_from_buffer(..);
-
-    GGML_API void gguf_free(struct gguf_context * ctx);
-
-    GGML_API const char * gguf_type_name(enum gguf_type type);
-
-    GGML_API int    gguf_get_version    (const struct gguf_context * ctx);
-    GGML_API size_t gguf_get_alignment  (const struct gguf_context * ctx);
-    GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx);
-    GGML_API void * gguf_get_data       (const struct gguf_context * ctx);
-
-    GGML_API int          gguf_get_n_kv(const struct gguf_context * ctx);
-    GGML_API int          gguf_find_key(const struct gguf_context * ctx, const char * key);
-    GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int key_id);
-
-    GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int key_id);
-    GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id);
-
-    // will abort if the wrong type is used for the key
-    GGML_API uint8_t      gguf_get_val_u8  (const struct gguf_context * ctx, int key_id);
-    GGML_API int8_t       gguf_get_val_i8  (const struct gguf_context * ctx, int key_id);
-    GGML_API uint16_t     gguf_get_val_u16 (const struct gguf_context * ctx, int key_id);
-    GGML_API int16_t      gguf_get_val_i16 (const struct gguf_context * ctx, int key_id);
-    GGML_API uint32_t     gguf_get_val_u32 (const struct gguf_context * ctx, int key_id);
-    GGML_API int32_t      gguf_get_val_i32 (const struct gguf_context * ctx, int key_id);
-    GGML_API float        gguf_get_val_f32 (const struct gguf_context * ctx, int key_id);
-    GGML_API uint64_t     gguf_get_val_u64 (const struct gguf_context * ctx, int key_id);
-    GGML_API int64_t      gguf_get_val_i64 (const struct gguf_context * ctx, int key_id);
-    GGML_API double       gguf_get_val_f64 (const struct gguf_context * ctx, int key_id);
-    GGML_API bool         gguf_get_val_bool(const struct gguf_context * ctx, int key_id);
-    GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id);
-    GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id);
-    GGML_API int          gguf_get_arr_n   (const struct gguf_context * ctx, int key_id);
-    GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
-    GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
-
-    GGML_API int            gguf_get_n_tensors    (const struct gguf_context * ctx);
-    GGML_API int            gguf_find_tensor      (const struct gguf_context * ctx, const char * name);
-    GGML_API size_t         gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
-    GGML_API char *         gguf_get_tensor_name  (const struct gguf_context * ctx, int i);
-    GGML_API enum ggml_type gguf_get_tensor_type  (const struct gguf_context * ctx, int i);
-
-    // removes key if it exists
-    GGML_API void gguf_remove_key(struct gguf_context * ctx, const char * key);
-
-    // overrides existing values or adds a new one
-    GGML_API void gguf_set_val_u8  (struct gguf_context * ctx, const char * key, uint8_t  val);
-    GGML_API void gguf_set_val_i8  (struct gguf_context * ctx, const char * key, int8_t   val);
-    GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val);
-    GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t  val);
-    GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val);
-    GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t  val);
-    GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float    val);
-    GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val);
-    GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t  val);
-    GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double   val);
-    GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool     val);
-    GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
-    GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n);
-    GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n);
-
-    // set or add KV pairs from another context
-    GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src);
-
-    // manage tensor info
-    GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);
-    GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);
-    GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size);
-
-    // writing gguf files can be done in 2 ways:
-    //
-    // - write the entire gguf_context to a binary file in a single pass:
-    //
-    //   gguf_write_to_file(ctx, fname);
-    //
-    // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
-    //
-    //   FILE * f = fopen(fname, "wb");
-    //   fseek(f, gguf_get_meta_size(ctx), SEEK_SET);
-    //   fwrite(f, ...);
-    //   void * data = gguf_meta_get_meta_data(ctx);
-    //   fseek(f, 0, SEEK_SET);
-    //   fwrite(f, data, gguf_get_meta_size(ctx));
-    //   free(data);
-    //   fclose(f);
-    //
-
-    // write the entire context to a binary file
-    GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
-
-    // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
-    GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
-    GGML_API void   gguf_get_meta_data(const struct gguf_context * ctx, void * data);
-
 #ifdef __cplusplus
     // restrict not standard in C++
 #    if defined(__GNUC__)
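ggml.h above gains GGML_OP_GATED_LINEAR_ATTN and the ggml_gated_linear_attn builder; the CPU kernel added further down reads the token count from v->ne[2], the head count from v->ne[1] and the sequence count from state->ne[1]. A rough construction sketch under those assumptions — the layouts expected by each backend are not spelled out in this patch, so the shapes here are illustrative only:

    // shapes inferred from the CPU kernel in this patch; treat them as an assumption
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 64*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    const int64_t head_size = 64, n_heads = 8, n_tokens = 16, n_seqs = 1;

    struct ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_heads, n_tokens);
    struct ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_heads, n_tokens);
    struct ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_heads, n_tokens);
    struct ggml_tensor * g = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_size, n_heads, n_tokens);
    struct ggml_tensor * s = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, head_size*head_size*n_heads, n_seqs);

    // per the CPU kernel, the result holds n_tokens * (head_size*n_heads) outputs
    // followed by the updated recurrent state
    struct ggml_tensor * out = ggml_gated_linear_attn(ctx, k, v, q, g, s, 0.125f /* = 1/sqrt(head_size) */);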
diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h
new file mode 100644
index 00000000000..79ee202062b
--- /dev/null
+++ b/ggml/include/gguf.h
@@ -0,0 +1,202 @@
+// This file contains functionality related to "GGUF" files, the binary file format used by ggml.
+// GGUF files have the following structure:
+//
+// 1. File magic "GGUF" (4 bytes).
+// 2. File version (uint32_t).
+// 3. Number of ggml tensors in file (int64_t).
+// 4. Number of key-value-pairs in file (int64_t).
+// 5. For each KV pair:
+//   1. The key (string).
+//   2. The value type (gguf_type).
+//   3a. If the value type is GGUF_TYPE_ARRAY:
+//     1. The type of the array (gguf_type).
+//     2. The number of elements in the array (uint64_t).
+//     3. The binary representation of each element in the array.
+//   3b. Otherwise:
+//     1. The binary representation of the value.
+// 6. For each ggml tensor:
+//   1. The tensor name (string).
+//   2. The number of dimensions of the tensor (uint32_t).
+//   3. For each dimension:
+//     1. The size of the tensor in the dimension (int64_t).
+//   4. The tensor data type (ggml_type).
+//   5. The tensor data offset in the tensor data binary blob (uint64_t).
+// 7. The tensor data binary blob (optional, aligned).
+//
+// Strings are serialized as the string length (uint64_t) followed by the C string without the null terminator.
+// All enums are stored as int32_t.
+// All bool values are stored as int8_t.
+// If the special key "general.alignment" (uint32_t) is defined it is used for alignment,
+//   otherwise GGUF_DEFAULT_ALIGNMENT is used.
+//
+// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de)
+
+#pragma once
+
+#include "ggml.h"
+
+#include <stdbool.h>
+#include <stdint.h>
+
+#define GGUF_MAGIC   "GGUF"
+#define GGUF_VERSION 3
+
+#define GGUF_KEY_GENERAL_ALIGNMENT "general.alignment"
+
+#define GGUF_DEFAULT_ALIGNMENT 32
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+    // types that can be stored as GGUF KV data
+    enum gguf_type {
+        GGUF_TYPE_UINT8   = 0,
+        GGUF_TYPE_INT8    = 1,
+        GGUF_TYPE_UINT16  = 2,
+        GGUF_TYPE_INT16   = 3,
+        GGUF_TYPE_UINT32  = 4,
+        GGUF_TYPE_INT32   = 5,
+        GGUF_TYPE_FLOAT32 = 6,
+        GGUF_TYPE_BOOL    = 7,
+        GGUF_TYPE_STRING  = 8,
+        GGUF_TYPE_ARRAY   = 9,
+        GGUF_TYPE_UINT64  = 10,
+        GGUF_TYPE_INT64   = 11,
+        GGUF_TYPE_FLOAT64 = 12,
+        GGUF_TYPE_COUNT,       // marks the end of the enum
+    };
+
+    struct gguf_context;
+
+    struct gguf_init_params {
+        bool no_alloc;
+
+        // if not NULL, create a ggml_context and allocate the tensor data in it
+        struct ggml_context ** ctx;
+    };
+
+    GGML_API struct gguf_context * gguf_init_empty(void);
+    GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
+    //GGML_API struct gguf_context * gguf_init_from_buffer(..);
+
+    GGML_API void gguf_free(struct gguf_context * ctx);
+
+    GGML_API const char * gguf_type_name(enum gguf_type type);
+
+    GGML_API uint32_t gguf_get_version    (const struct gguf_context * ctx);
+    GGML_API size_t   gguf_get_alignment  (const struct gguf_context * ctx);
+    GGML_API size_t   gguf_get_data_offset(const struct gguf_context * ctx);
+
+    GGML_API int64_t      gguf_get_n_kv(const struct gguf_context * ctx);
+    GGML_API int64_t      gguf_find_key(const struct gguf_context * ctx, const char * key); // returns -1 if key is not found
+    GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int64_t key_id);
+
+    GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id);
+
+    // will abort if the wrong type is used for the key
+    GGML_API uint8_t      gguf_get_val_u8  (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API int8_t       gguf_get_val_i8  (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API uint16_t     gguf_get_val_u16 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API int16_t      gguf_get_val_i16 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API uint32_t     gguf_get_val_u32 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API int32_t      gguf_get_val_i32 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API float        gguf_get_val_f32 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API uint64_t     gguf_get_val_u64 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API int64_t      gguf_get_val_i64 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API double       gguf_get_val_f64 (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API bool         gguf_get_val_bool(const struct gguf_context * ctx, int64_t key_id);
+    GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int64_t key_id);
+    GGML_API const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id);
+    GGML_API size_t       gguf_get_arr_n   (const struct gguf_context * ctx, int64_t key_id);
+
+    // get raw pointer to the first element of the array with the given key_id
+    // for bool arrays, note that they are always stored as int8 on all platforms (usually this makes no difference)
+    GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id);
+
+    // get ith C string from array with given key_id
+    GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int64_t key_id, size_t i);
+
+    GGML_API int64_t        gguf_get_n_tensors    (const struct gguf_context * ctx);
+    GGML_API int64_t        gguf_find_tensor      (const struct gguf_context * ctx, const char * name); // returns -1 if the tensor is not found
+    GGML_API size_t         gguf_get_tensor_offset(const struct gguf_context * ctx, int64_t tensor_id);
+    GGML_API const char *   gguf_get_tensor_name  (const struct gguf_context * ctx, int64_t tensor_id);
+    GGML_API enum ggml_type gguf_get_tensor_type  (const struct gguf_context * ctx, int64_t tensor_id);
+    GGML_API size_t         gguf_get_tensor_size  (const struct gguf_context * ctx, int64_t tensor_id);
+
+    // removes key if it exists, returns id that the key had prior to removal (-1 if it didn't exist)
+    GGML_API int64_t gguf_remove_key(struct gguf_context * ctx, const char * key);
+
+    // overrides an existing KV pair or adds a new one, the new KV pair is always at the back
+    GGML_API void gguf_set_val_u8  (struct gguf_context * ctx, const char * key, uint8_t      val);
+    GGML_API void gguf_set_val_i8  (struct gguf_context * ctx, const char * key, int8_t       val);
+    GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t     val);
+    GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t      val);
+    GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t     val);
+    GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t      val);
+    GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float        val);
+    GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t     val);
+    GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t      val);
+    GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double       val);
+    GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool         val);
+    GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val);
+
+    // creates a new array with n elements of the given type and copies the corresponding number of bytes from data
+    GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, size_t n);
+
+    // creates a new array with n strings and copies the corresponding strings from data
+    GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, size_t n);
+
+    // set or add KV pairs from another context
+    GGML_API void gguf_set_kv(struct gguf_context * ctx, const struct gguf_context * src);
+
+    // add tensor to GGUF context, tensor name must be unique
+    GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor);
+
+    // after changing a tensor's type, the offsets of all tensors with higher indices are immediately recalculated
+    //   in such a way that the tensor data remains as one contiguous block (except for padding)
+    GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type);
+
+    // assumes that at least gguf_get_tensor_size bytes can be read from data
+    GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data);
+
+    // writing gguf files can be done in 3 ways:
+    //
+    // - write the entire gguf_context to a binary file in a single pass:
+    //
+    //   gguf_write_to_file(ctx, fname, /*only_meta =*/ false);
+    //
+    // - write only the meta data to a file, then re-open the file and append the tensor data:
+    //
+    //   gguf_write_to_file(ctx, fname, /*only_meta =*/ true);
+    //   FILE * f = fopen(fname, "ab");
+    //   fwrite(f, ...); // write tensor data
+    //   fclose(f);
+    //
+    // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data:
+    //
+    //   FILE * f = fopen(fname, "wb");
+    //   const size_t size_meta = gguf_get_meta_size(ctx);
+    //   fseek(f, size_meta, SEEK_SET);
+    //   fwrite(f, ...); // write tensor data
+    //   void * data = malloc(size_meta);
+    //   gguf_get_meta_data(ctx, data);
+    //   rewind(f);
+    //   fwrite(data, 1, size_meta, f);
+    //   free(data);
+    //   fclose(f);
+    //
+
+    // write the entire context to a binary file
+    GGML_API bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta);
+
+    // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding
+    GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
+
+    // writes the meta data to pointer "data"
+    GGML_API void   gguf_get_meta_data(const struct gguf_context * ctx, void * data);
+
+#ifdef  __cplusplus
+}
+#endif
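The new header is self-contained enough for a small metadata dump; a sketch using only the functions declared above (the file path is a placeholder, and passing ctx = NULL means no ggml_context is created for tensor data, which is all a dump like this needs):

    #include "gguf.h"
    #include <stdio.h>

    int main(void) {
        struct gguf_init_params params = { /*no_alloc =*/ true, /*ctx =*/ NULL };
        struct gguf_context * gctx = gguf_init_from_file("model.gguf", params);
        if (!gctx) {
            return 1;
        }

        printf("version %u, alignment %zu, data offset %zu\n",
               gguf_get_version(gctx), gguf_get_alignment(gctx), gguf_get_data_offset(gctx));

        for (int64_t i = 0; i < gguf_get_n_kv(gctx); i++) {
            printf("kv[%3lld] %-40s %s\n", (long long) i,
                   gguf_get_key(gctx, i), gguf_type_name(gguf_get_kv_type(gctx, i)));
        }

        for (int64_t i = 0; i < gguf_get_n_tensors(gctx); i++) {
            printf("tensor[%3lld] %-40s %zu bytes\n", (long long) i,
                   gguf_get_tensor_name(gctx, i), gguf_get_tensor_size(gctx, i));
        }

        gguf_free(gctx);
        return 0;
    }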
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 84101c32c2b..ae1cd23376c 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -208,6 +208,7 @@ add_library(ggml-base
             ../include/ggml-backend.h
             ../include/ggml-cpp.h
             ../include/ggml-opt.h
+            ../include/gguf.h
             ggml.c
             ggml-alloc.c
             ggml-backend.cpp
@@ -215,7 +216,8 @@ add_library(ggml-base
             ggml-threading.cpp
             ggml-threading.h
             ggml-quants.c
-            ggml-quants.h)
+            ggml-quants.h
+            gguf.cpp)
 
 target_include_directories(ggml-base PRIVATE .)
 
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 7ddd178b5f3..955ed505fa1 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -574,4 +574,9 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
     ggml_backend_load_best("opencl", silent, dir_path);
     ggml_backend_load_best("musa", silent, dir_path);
     ggml_backend_load_best("cpu", silent, dir_path);
+    // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
+    const char * backend_path = std::getenv("GGML_BACKEND_PATH");
+    if (backend_path) {
+        ggml_backend_load(backend_path);
+    }
 }
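The GGML_BACKEND_PATH check added above lets a single out-of-tree backend be picked up without rebuilding. A hedged sketch, assuming the usual ggml_backend_load_all entry point and a POSIX environment; the library path is hypothetical:

    #include <stdlib.h>        // setenv (POSIX)
    #include "ggml-backend.h"  // ggml_backend_load_all

    // point GGML_BACKEND_PATH at the out-of-tree backend before loading (path is illustrative)
    setenv("GGML_BACKEND_PATH", "/opt/ggml/libggml-custom.so", 1);
    ggml_backend_load_all();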
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index e2d6c405668..dba7be33b88 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -764,7 +764,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
         if (tensor->op != GGML_OP_ROPE && src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
             int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
             // check if a backend with higher prio wants to offload the op
-            if (src_backend_id == sched->n_backends - 1) {
+            if (src_backend_id == sched->n_backends - 1 && ggml_backend_buffer_is_host(src->buffer)) {
                 for (int b = 0; b < src_backend_id; b++) {
                     if (ggml_backend_supports_op(sched->backends[b], tensor) && ggml_backend_offload_op(sched->backends[b], tensor)) {
                         SET_CAUSE(tensor, "1.off");
diff --git a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
index 622c63f1f8e..b311a5b1c4b 100644
--- a/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
+++ b/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp
@@ -4169,6 +4169,8 @@ static ggml_backend_buffer_t ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(g
     buffer->buft              = buft;
     buffer->iface.init_tensor = ggml_backend_cpu_aarch64_buffer_init_tensor;
     buffer->iface.set_tensor  = ggml_backend_cpu_aarch64_buffer_set_tensor;
+    buffer->iface.get_tensor  = nullptr;
+    buffer->iface.cpy_tensor  = nullptr;
     return buffer;
 }
 
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index b7fefb9ddfd..2966ff7682d 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -11803,9 +11803,9 @@ static void ggml_compute_forward_add_rel_pos(
 static void ggml_compute_forward_rwkv_wkv6_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const int64_t T = dst->src[1]->ne[3];
+    const int64_t T = dst->src[1]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t HEADS = dst->src[1]->ne[2];
+    const int64_t HEADS = dst->src[1]->ne[1];
     const int64_t n_seqs = dst->src[5]->ne[1];
     const int64_t head_size = C / HEADS;
 
@@ -12000,6 +12000,197 @@ static void ggml_compute_forward_rwkv_wkv6(
     }
 }
 
+// ggml_compute_forward_gla
+
+static void ggml_compute_forward_gla_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[4]->ne[1];
+    const int64_t head_size = C / HEADS;
+    const float scale = ggml_get_op_params_f32(dst, 0);
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * k = (float *) dst->src[0]->data;
+    float * v = (float *) dst->src[1]->data;
+    float * q = (float *) dst->src[2]->data;
+    float * g = (float *) dst->src[3]->data;
+
+    size_t t_stride = HEADS * head_size; // same as C
+
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
+
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
+
+
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define GLA_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define GLA_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define GLA_VECTOR_SIZE 4
+    #endif
+
+    #ifdef GLA_VECTOR_SIZE
+        const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X q_vec = GGML_F32X_SET1(q_val);
+                    GGML_F32X g_vec = GGML_F32X_SET1(g_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * GLA_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load GLA_VECTOR_SIZE elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = prev_state * g + kv
+                        GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
+
+                        // Update dst: dst += temp * q
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
+                    }
+
+                    // Handle remaining elements (not reached when head_size is a multiple of GLA_VECTOR_SIZE)
+                    for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val + prev_state_val * g_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = prev_state_val * g_val + kv_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+    #endif
+}
+
+
+static void ggml_compute_forward_gla(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gla_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
@@ -12749,6 +12940,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv6(params, tensor);
             } break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            {
+                ggml_compute_forward_gla(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -13047,6 +13242,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
index 8fce576c3e4..c22a662876c 100644
--- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp
+++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -54,6 +54,7 @@
 #include "ggml-quants.h"
 
 #include 
+#include 
 
 #ifdef _MSC_VER
 #define NOINLINE __declspec(noinline)
@@ -1051,6 +1052,704 @@ class tinyBLAS_Q0_AVX {
       } \
    } \
 
+template <typename TA, typename TB, typename TC>
+class tinyBLAS_Q0_PPC {
+  public:
+    tinyBLAS_Q0_PPC(int64_t k,
+                const TA *A, int64_t lda,
+                const TB *B, int64_t ldb,
+                TC *C, int64_t ldc,
+                int ith, int nth)
+        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+    }
+
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
+    }
+
+  private:
+
+    template<int RM, int RN>
+    inline void save_res(int ii, int jj, int idx, vector float* fin_res) {
+       for (int I = 0; I < RM; I++) {
+          for (int J = 0; J < RN; J++) {
+             *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J);
+          }
+       }
+    }
+
+    template<int size>
+    inline void compute(acc_t* ACC, int c_idx, int s_idx, std::array<int, size>& comparray, vector float* vs, vector float* fin_res) {
+       vector signed int vec_C[4];
+       vector float CA[4] = {0};
+       vector float res[4] = {0};
+       __builtin_mma_disassemble_acc(vec_C, ACC);
+       for (int i = 0; i < 4; i++) {
+          CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0));
+          res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
+          fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]);
+       }
+    }
+
+    template<typename VA, typename VB>
+    void packNormal(const TA* a, int64_t lda, int rows, int cols, VA* vec, bool flip) {
+        int64_t i, j;
+        TA *aoffset = NULL;
+        VA *vecOffset = NULL;
+        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
+        VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0};
+        VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0};
+        VB t1, t2, t3, t4, t5, t6, t7, t8;
+        vector unsigned char xor_vector;
+        uint8_t flip_vec = 0x80;
+        xor_vector = vec_splats(flip_vec);
+        vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23};
+        vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31};
+        vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27};
+        vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31};
+
+        aoffset = const_cast<TA*>(a);
+        vecOffset = vec;
+        j = (rows >> 3);
+        if (j > 0) {
+            do {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            aoffset4 = aoffset3 + lda;
+            aoffset5 = aoffset4 + lda;
+            aoffset6 = aoffset5 + lda;
+            aoffset7 = aoffset6 + lda;
+            aoffset8 = aoffset7 + lda;
+            aoffset += 8 * lda;
+
+            i = (cols >> 3);
+            if (i > 0) {
+               do {
+                    C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
+                    C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
+                    C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
+                    C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
+                    C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs);
+                    C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs);
+                    C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs);
+                    C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs);
+
+                    __builtin_vsx_disassemble_pair(c1, &C1);
+                    __builtin_vsx_disassemble_pair(c2, &C2);
+                    __builtin_vsx_disassemble_pair(c3, &C3);
+                    __builtin_vsx_disassemble_pair(c4, &C4);
+                    __builtin_vsx_disassemble_pair(c5, &C5);
+                    __builtin_vsx_disassemble_pair(c6, &C6);
+                    __builtin_vsx_disassemble_pair(c7, &C7);
+                    __builtin_vsx_disassemble_pair(c8, &C8);
+
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+
+                    t1 = vec_perm(c5[0], c6[0], swiz1);
+                    t2 = vec_perm(c5[0], c6[0], swiz2);
+                    t3 = vec_perm(c7[0], c8[0], swiz1);
+                    t4 = vec_perm(c7[0], c8[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset+128);
+                    vec_xst(t6, 0, vecOffset+144);
+                    vec_xst(t7, 0, vecOffset+160);
+                    vec_xst(t8, 0, vecOffset+176);
+
+                    t1 = vec_perm(c5[1], c6[1], swiz1);
+                    t2 = vec_perm(c5[1], c6[1], swiz2);
+                    t3 = vec_perm(c7[1], c8[1], swiz1);
+                    t4 = vec_perm(c7[1], c8[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset+192);
+                    vec_xst(t6, 0, vecOffset+208);
+                    vec_xst(t7, 0, vecOffset+224);
+                    vec_xst(t8, 0, vecOffset+240);
+
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    aoffset4 += lda;
+                    aoffset5 += lda;
+                    aoffset6 += lda;
+                    aoffset7 += lda;
+                    aoffset8 += lda;
+                    vecOffset += 256;
+                    i--;
+               } while(i > 0);
+            }
+            j--;
+        } while(j > 0);
+    }
+
+    if (rows & 4) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            aoffset4 = aoffset3 + lda;
+            aoffset += 4 * lda;
+
+        i = (cols >> 3);
+            if (i > 0) {
+               do {
+                    C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
+                    C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
+                    C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
+                    C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs);
+
+                    __builtin_vsx_disassemble_pair(c1, &C1);
+                    __builtin_vsx_disassemble_pair(c2, &C2);
+                    __builtin_vsx_disassemble_pair(c3, &C3);
+                    __builtin_vsx_disassemble_pair(c4, &C4);
+
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    aoffset4 += lda;
+                    vecOffset += 128;
+                    i--;
+               } while(i > 0);
+            }
+        }
+        if (rows & 3) {
+            aoffset1 = aoffset;
+            aoffset2 = aoffset1 + lda;
+            aoffset3 = aoffset2 + lda;
+            i = (cols >> 3);
+        if (i > 0) {
+                do {
+                    switch(rows) {
+                        case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs);
+                                __builtin_vsx_disassemble_pair(c3, &C3);
+                        case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs);
+                                __builtin_vsx_disassemble_pair(c2, &C2);
+                        case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs);
+                                __builtin_vsx_disassemble_pair(c1, &C1);
+                                break;
+                    }
+                    t1 = vec_perm(c1[0], c2[0], swiz1);
+                    t2 = vec_perm(c1[0], c2[0], swiz2);
+                    t3 = vec_perm(c3[0], c4[0], swiz1);
+                    t4 = vec_perm(c3[0], c4[0], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset);
+                    vec_xst(t6, 0, vecOffset+16);
+                    vec_xst(t7, 0, vecOffset+32);
+                    vec_xst(t8, 0, vecOffset+48);
+
+                    t1 = vec_perm(c1[1], c2[1], swiz1);
+                    t2 = vec_perm(c1[1], c2[1], swiz2);
+                    t3 = vec_perm(c3[1], c4[1], swiz1);
+                    t4 = vec_perm(c3[1], c4[1], swiz2);
+                    t5 = vec_perm(t1, t3, swiz3);
+                    t6 = vec_perm(t1, t3, swiz4);
+                    t7 = vec_perm(t2, t4, swiz3);
+                    t8 = vec_perm(t2, t4, swiz4);
+                    if (flip == true) {
+                       t5 = vec_xor(t5, xor_vector);
+                       t6 = vec_xor(t6, xor_vector);
+                       t7 = vec_xor(t7, xor_vector);
+                       t8 = vec_xor(t8, xor_vector);
+                    }
+                    vec_xst(t5, 0, vecOffset+64);
+                    vec_xst(t6, 0, vecOffset+80);
+                    vec_xst(t7, 0, vecOffset+96);
+                    vec_xst(t8, 0, vecOffset+112);
+
+                    aoffset1 += lda;
+                    aoffset2 += lda;
+                    aoffset3 += lda;
+                    vecOffset += 128;
+                    i--;
+               } while(i > 0);
+            }
+        }
+    }
+
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        int m_rem = MIN(m - m0, 8);
+        int n_rem = MIN(n - n0, 8);
+        // TODO: KERNEL_16x8 and KERNEL_8x16 currently have performance
+        // issues; the code below will be enabled once they are resolved.
+        /*if (m_rem >= 16 && n_rem >= 8) {
+            mc = 16;
+            nc = 8;
+            gemm<16,8>(m0, m, n0, n);
+        } else if(m_rem >= 8 && n_rem >= 16) {
+            mc = 8;
+            nc = 16;
+            gemm<8,16>(m0, m, n0, n);
+        }*/
+        if (m_rem >= 8 && n_rem >= 8) {
+            mc = 8;
+            nc = 8;
+            gemm<8,8>(m0, m, n0, n);
+        } else if (m_rem >= 4 && n_rem >= 8) {
+            mc = 4;
+            nc = 8;
+            gemm<4,8>(m0, m, n0, n);
+        } else if (m_rem >= 8 && n_rem >= 4) {
+            mc = 8;
+            nc = 4;
+            gemm<8,4>(m0, m, n0, n);
+        } else if (m_rem >= 4 && n_rem >= 4) {
+            mc = 4;
+            nc = 4;
+            gemm_small<4, 4>(m0, m, n0, n);
+        } else if ((m_rem < 4) && (n_rem > 4)) {
+            nc = 4;
+            switch(m_rem) {
+                case 1:
+                    mc = 1;
+                    gemm_small<1, 4>(m0, m, n0, n);
+                    break;
+                case 2:
+                    mc = 2;
+                    gemm_small<2, 4>(m0, m, n0, n);
+                    break;
+                case 3:
+                    mc = 3;
+                    gemm_small<3, 4>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        } else if ((m_rem > 4) && (n_rem < 4)) {
+            mc = 4;
+            switch(n_rem) {
+                case 1:
+                    nc = 1;
+                    gemm_small<4, 1>(m0, m, n0, n);
+                    break;
+                case 2:
+                    nc = 2;
+                    gemm_small<4, 2>(m0, m, n0, n);
+                    break;
+                case 3:
+                    nc = 3;
+                    gemm_small<4, 3>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        } else {
+            switch((m_rem << 4) | n_rem) {
+                case 0x43:
+                    mc = 4;
+                    nc = 3;
+                    gemm_small<4, 3>(m0, m, n0, n);
+                    break;
+                case 0x42:
+                    mc = 4;
+                    nc = 2;
+                    gemm_small<4, 2>(m0, m, n0, n);
+                    break;
+                case 0x41:
+                    mc = 4;
+                    nc = 1;
+                    gemm_small<4, 1>(m0, m, n0, n);
+                    break;
+                case 0x34:
+                    mc = 3;
+                    nc = 4;
+                    gemm_small<3, 4>(m0, m, n0, n);
+                    break;
+                case 0x33:
+                    mc = 3;
+                    nc = 3;
+                    gemm_small<3, 3>(m0, m, n0, n);
+                    break;
+                case 0x32:
+                    mc = 3;
+                    nc = 2;
+                    gemm_small<3, 2>(m0, m, n0, n);
+                    break;
+                case 0x31:
+                    mc = 3;
+                    nc = 1;
+                    gemm_small<3, 1>(m0, m, n0, n);
+                    break;
+                case 0x24:
+                    mc = 2;
+                    nc = 4;
+                    gemm_small<2, 4>(m0, m, n0, n);
+                    break;
+                case 0x23:
+                    mc = 2;
+                    nc = 3;
+                    gemm_small<2, 3>(m0, m, n0, n);
+                    break;
+                case 0x22:
+                    mc = 2;
+                    nc = 2;
+                    gemm_small<2, 2>(m0, m, n0, n);
+                    break;
+                case 0x21:
+                    mc = 2;
+                    nc = 1;
+                    gemm_small<2, 1>(m0, m, n0, n);
+                    break;
+                case 0x14:
+                    mc = 1;
+                    nc = 4;
+                    gemm_small<1, 4>(m0, m, n0, n);
+                    break;
+                case 0x13:
+                    mc = 1;
+                    nc = 3;
+                    gemm_small<1, 3>(m0, m, n0, n);
+                    break;
+                case 0x12:
+                    mc = 1;
+                    nc = 2;
+                    gemm_small<1, 2>(m0, m, n0, n);
+                    break;
+                case 0x11:
+                    mc = 1;
+                    nc = 1;
+                    gemm_small<1, 1>(m0, m, n0, n);
+                    break;
+                default:
+                    return;
+            }
+        }
+        mp = m0 + (m - m0) / mc * mc;
+        np = n0 + (n - n0) / nc * nc;
+        mnpack(mp, m, n0, np);
+        mnpack(m0, m, np, n);
+    }
+
+    void KERNEL_4x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[8], vec_B[16] = {0};
+        acc_t acc_0, acc_1;
+        std::array<int, 4> comparray;
+        vector float fin_res[8] = {0};
+        vector float vs[8] = {0};
+        for (int l = 0; l < k; l++) {
+            __builtin_mma_xxsetaccz(&acc_0);
+            __builtin_mma_xxsetaccz(&acc_1);
+            packNormal((A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false);
+            packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
+            for(int x = 0; x < 8; x++) {
+                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]);
+            }
+            for (int I = 0; I<4; I++) {
+                for (int J = 0; J<4; J++) {
+                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+                    *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
+                }
+            }
+            auto aoffset = A+(ii*lda)+l;
+            for (int i = 0; i < 4; i++) {
+                comparray[i] = 0;
+                int ca = 0;
+                const int8_t *at = aoffset->qs;
+                for (int j = 0; j < 32; j++)
+                    ca += (int)*at++;
+                comparray[i] = ca;
+                aoffset += lda;
+            }
+            compute<4>(&acc_0, 0, 0, comparray, vs, fin_res);
+            compute<4>(&acc_1, 0, 4, comparray, vs, fin_res);
+        }
+        save_res<4, 4>(ii, jj, 0, fin_res);
+        save_res<4, 4>(ii, jj+4, 4, fin_res);
+    }
+
+    void KERNEL_8x4(int64_t ii, int64_t jj) {
+        vec_t vec_A[16], vec_B[8] = {0};
+        acc_t acc_0, acc_1;
+        std::array<int, 8> comparray;
+        vector float fin_res[8] = {0};
+        vector float vs[8] = {0};
+        for (int l = 0; l < k; l++) {
+            __builtin_mma_xxsetaccz(&acc_0);
+            __builtin_mma_xxsetaccz(&acc_1);
+            packNormal((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            packNormal((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true);
+            for(int x = 0; x < 8; x++) {
+                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
+            }
+            for (int I = 0; I<8; I++) {
+                for (int J = 0; J<4; J++) {
+                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+                }
+            }
+            auto aoffset = A+(ii*lda)+l;
+            for (int i = 0; i < 8; i++) {
+                comparray[i] = 0;
+                int ca = 0;
+                const int8_t *at = aoffset->qs;
+                for (int j = 0; j < 32; j++)
+                    ca += (int)*at++;
+                comparray[i] = ca;
+                aoffset += lda;
+            }
+            compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
+            compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
+        }
+        save_res<4, 4>(ii, jj, 0, fin_res);
+        save_res<4, 4>(ii+4, jj, 4, fin_res);
+    }
+
+    void KERNEL_8x8(int64_t ii, int64_t jj) {
+        vec_t vec_A[16], vec_B[16] = {0};
+        acc_t acc_0, acc_1, acc_2, acc_3;
+        std::array<int, 8> comparray;
+        vector float fin_res[16] = {0};
+        vector float vs[16] = {0};
+        for (int l = 0; l < k; l++) {
+            __builtin_mma_xxsetaccz(&acc_0);
+            __builtin_mma_xxsetaccz(&acc_1);
+            __builtin_mma_xxsetaccz(&acc_2);
+            __builtin_mma_xxsetaccz(&acc_3);
+            packNormal((A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false);
+            packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true);
+            for(int x = 0; x < 8; x++) {
+                __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]);
+                __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]);
+                __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]);
+            }
+            for (int I = 0; I<8; I++) {
+                for (int J = 0; J<4; J++) {
+                    *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+                    *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d));
+                }
+            }
+            auto aoffset = A+(ii*lda)+l;
+            for (int i = 0; i < 8; i++) {
+                comparray[i] = 0;
+                int ca = 0;
+                const int8_t *at = aoffset->qs;
+                for (int j = 0; j < 32; j++)
+                    ca += (int)*at++;
+                comparray[i] = ca;
+                aoffset += lda;
+            }
+            compute<8>(&acc_0, 0, 0, comparray, vs, fin_res);
+            compute<8>(&acc_1, 4, 4, comparray, vs, fin_res);
+            compute<8>(&acc_2, 0, 8, comparray, vs, fin_res);
+            compute<8>(&acc_3, 4, 12, comparray, vs, fin_res);
+        }
+        save_res<4, 4>(ii, jj, 0, fin_res);
+        save_res<4, 4>(ii+4, jj, 4, fin_res);
+        save_res<4, 4>(ii, jj+4, 8, fin_res);
+        save_res<4, 4>(ii+4, jj+4, 12, fin_res);
+    }
+
+    template<int RM, int RN>
+    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        vec_t vec_A[8], vec_B[8] = {0};
+        vector signed int vec_C[4];
+        acc_t acc_0;
+
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            std::array<int, RM> comparray;
+            vector float res[4] = {0};
+            vector float fin_res[4] = {0};
+            vector float vs[4] = {0};
+            vector float CA[4] = {0};
+            __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value
+            __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value
+            for (int l = 0; l < k; l++) {
+                __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead
+                __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead
+                __builtin_mma_xxsetaccz(&acc_0);
+                packNormal((A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false);
+                packNormal((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true);
+                for(int x = 0; x < 8; x+=4) {
+                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]);
+                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]);
+                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]);
+                    __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]);
+                }
+                for (int I = 0; I<RM; I++) {
+                    for (int J = 0; J<RN; J++) {
+                        *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d));
+                    }
+                }
+                __builtin_mma_disassemble_acc(vec_C, &acc_0);
+                auto aoffset = A+(ii*lda)+l;
+                for (int i = 0; i < RM; i++) {
+                    comparray[i] = 0;
+                    int ca = 0;
+                    const int8_t *at = aoffset->qs;
+                    for (int j = 0; j < 32; j++)
+                        ca += (int)*at++;
+                    comparray[i] = ca;
+                    aoffset += lda;
+                }
+
+                for (int i = 0; i < RM; i++) {
+                    CA[i] = vec_splats((float)(((double)comparray[i]) * -128.0));
+                    res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]);
+                    fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]);
+                }
+            }
+            save_res<RM, RN>(ii, jj, 0, fin_res);
+        }
+    }
+
+    template<int RM, int RN>
+    inline void kernel(int64_t ii, int64_t jj) {
+       if constexpr(RM == 4 && RN == 8) {
+          KERNEL_4x8(ii,jj);
+       } else if constexpr(RM == 8 && RN == 4) {
+          KERNEL_8x4(ii,jj);
+       } else if constexpr(RM == 8 && RN == 8) {
+          KERNEL_8x8(ii,jj);
+       } else {
+          static_assert(false, "RN/RM values not supported");
+       }
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            kernel<RM, RN>(ii, jj);
+        }
+    }
+
+    const TA *const A;
+    const TB *const B;
+    TC *C;
+    TA *At;
+    TB *Bt;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
+
 template <typename TA, typename TB, typename TC>
 class tinyBLAS_PPC {
   public:
@@ -1070,13 +1769,17 @@ class tinyBLAS_PPC {
 
     void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
 
-    void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) {
+    template<typename VA>
+    void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) {
         int64_t i, j;
-        float *aoffset = NULL, *boffset = NULL;
-        float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
-        float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
-
-        aoffset = const_cast<float*>(a);
+        TA *aoffset = NULL, *boffset = NULL;
+        TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
+        TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
+        __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
+        VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0};
+        VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0};
+        VA t1, t2, t3, t4, t5, t6, t7, t8;
+        aoffset = const_cast<TA*>(a);
         boffset = vec;
         j = (rows >> 3);
         if (j > 0) {
@@ -1092,9 +1795,6 @@ class tinyBLAS_PPC {
                 aoffset += 8 * lda;
                 i = (cols >> 3);
                 if (i > 0) {
-                    __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
-                    vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2];
-                    vector float t1, t2, t3, t4, t5, t6, t7, t8;
                     do {
                         C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
                         C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
@@ -1174,21 +1874,19 @@ class tinyBLAS_PPC {
                     } while(i > 0);
                 }
                 if (cols & 4) {
-                    vector float c1, c2, c3, c4, c5, c6, c7, c8;
-                    vector float t1, t2, t3, t4, t5, t6, t7, t8;
-                    c1 = vec_xl(0, aoffset1);
-                    c2 = vec_xl(0, aoffset2);
-                    c3 = vec_xl(0, aoffset3);
-                    c4 = vec_xl(0, aoffset4);
-                    c5 = vec_xl(0, aoffset5);
-                    c6 = vec_xl(0, aoffset6);
-                    c7 = vec_xl(0, aoffset7);
-                    c8 = vec_xl(0, aoffset8);
-
-                    t1 = vec_mergeh(c1, c2);
-                    t2 = vec_mergeh(c3, c4);
-                    t3 = vec_mergeh(c5, c6);
-                    t4 = vec_mergeh(c7, c8);
+                    c1[0] = vec_xl(0, aoffset1);
+                    c2[0] = vec_xl(0, aoffset2);
+                    c3[0] = vec_xl(0, aoffset3);
+                    c4[0] = vec_xl(0, aoffset4);
+                    c5[0] = vec_xl(0, aoffset5);
+                    c6[0] = vec_xl(0, aoffset6);
+                    c7[0] = vec_xl(0, aoffset7);
+                    c8[0] = vec_xl(0, aoffset8);
+
+                    t1 = vec_mergeh(c1[0], c2[0]);
+                    t2 = vec_mergeh(c3[0], c4[0]);
+                    t3 = vec_mergeh(c5[0], c6[0]);
+                    t4 = vec_mergeh(c7[0], c8[0]);
                     t5 = vec_xxpermdi(t1, t2, 0);
                     t6 = vec_xxpermdi(t3, t4, 0);
                     t7 = vec_xxpermdi(t1, t2, 3);
@@ -1198,10 +1896,10 @@ class tinyBLAS_PPC {
                     vec_xst(t7, 0, boffset+8);
                     vec_xst(t8, 0, boffset+12);
 
-                    t1 = vec_mergel(c1, c2);
-                    t2 = vec_mergel(c3, c4);
-                    t3 = vec_mergel(c5, c6);
-                    t4 = vec_mergel(c7, c8);
+                    t1 = vec_mergel(c1[0], c2[0]);
+                    t2 = vec_mergel(c3[0], c4[0]);
+                    t3 = vec_mergel(c5[0], c6[0]);
+                    t4 = vec_mergel(c7[0], c8[0]);
                     t5 = vec_xxpermdi(t1, t2, 0);
                     t6 = vec_xxpermdi(t3, t4, 0);
                     t7 = vec_xxpermdi(t1, t2, 3);
@@ -1223,9 +1921,6 @@ class tinyBLAS_PPC {
             aoffset += 4 * lda;
             i = (cols >> 3);
             if (i > 0) {
-                __vector_pair C1, C2, C3, C4;
-                vector float c1[2], c2[2], c3[2], c4[2];
-                vector float t1, t2, t3, t4, t5, t6, t7, t8;
                 do {
                     C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
                     C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
@@ -1272,22 +1967,20 @@ class tinyBLAS_PPC {
             }
 
             if (cols & 4) {
-                vector float c1, c2, c3, c4;
-                vector float t1, t2, t3, t4;
-                c1 = vec_xl(0, aoffset1);
-                c2 = vec_xl(0, aoffset2);
-                c3 = vec_xl(0, aoffset3);
-                c4 = vec_xl(0, aoffset4);
-
-                t1 = vec_mergeh(c1, c2);
-                t2 = vec_mergeh(c3, c4);
+                c1[0] = vec_xl(0, aoffset1);
+                c2[0] = vec_xl(0, aoffset2);
+                c3[0] = vec_xl(0, aoffset3);
+                c4[0] = vec_xl(0, aoffset4);
+
+                t1 = vec_mergeh(c1[0], c2[0]);
+                t2 = vec_mergeh(c3[0], c4[0]);
                 t3 = vec_xxpermdi(t1, t2, 0);
                 t4 = vec_xxpermdi(t1, t2, 3);
                 vec_xst(t3, 0, boffset);
                 vec_xst(t4, 0, boffset+4);
 
-                t1 = vec_mergel(c1, c2);
-                t2 = vec_mergel(c3, c4);
+                t1 = vec_mergel(c1[0], c2[0]);
+                t2 = vec_mergel(c3[0], c4[0]);
                 t3 = vec_xxpermdi(t1, t2, 0);
                 t4 = vec_xxpermdi(t1, t2, 3);
                 vec_xst(t3, 0, boffset+8);
@@ -1299,21 +1992,19 @@ class tinyBLAS_PPC {
             aoffset2 = aoffset1 + lda;
             aoffset3 = aoffset2 + lda;
             if (cols & 4) {
-                vector float c1, c2, c3, c4 = {0};
-                vector float t1, t2, t3, t4;
-                c1 = vec_xl(0, aoffset1);
-                c2 = vec_xl(0, aoffset2);
-                c3 = vec_xl(0, aoffset3);
-
-                t1 = vec_mergeh(c1, c2);
-                t2 = vec_mergeh(c3, c4);
+                c1[0] = vec_xl(0, aoffset1);
+                c2[0] = vec_xl(0, aoffset2);
+                c3[0] = vec_xl(0, aoffset3);
+
+                t1 = vec_mergeh(c1[0], c2[0]);
+                t2 = vec_mergeh(c3[0], c4[0]);
                 t3 = vec_xxpermdi(t1, t2, 0);
                 t4 = vec_xxpermdi(t1, t2, 3);
                 vec_xst(t3, 0, boffset);
                 vec_xst(t4, 0, boffset+4);
 
-                t1 = vec_mergel(c1, c2);
-                t2 = vec_mergel(c3, c4);
+                t1 = vec_mergel(c1[0], c2[0]);
+                t2 = vec_mergel(c3[0], c4[0]);
                 t3 = vec_xxpermdi(t1, t2, 0);
                 t4 = vec_xxpermdi(t1, t2, 3);
                 vec_xst(t3, 0, boffset+8);
@@ -1321,14 +2012,13 @@ class tinyBLAS_PPC {
             }
         }
     }
-
     void KERNEL_4x4(int64_t ii, int64_t jj) {
         vec_t vec_A[4], vec_B[4], vec_C[4];
         acc_t acc_0;
         __builtin_mma_xxsetaccz(&acc_0);
         for (int l = 0; l < k; l+=4) {
-            READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
-            READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+            packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
+            packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
@@ -1343,8 +2033,8 @@ class tinyBLAS_PPC {
         __builtin_mma_xxsetaccz(&acc_0);
         __builtin_mma_xxsetaccz(&acc_1);
         for (int64_t l = 0; l < k; l+=4) {
-            READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
-            READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
+            packTranspose<vector float>(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A);
+            packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
             __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
@@ -1364,8 +2054,8 @@ class tinyBLAS_PPC {
         __builtin_mma_xxsetaccz(&acc_0);
         __builtin_mma_xxsetaccz(&acc_1);
         for (int64_t l = 0; l < k; l+=4) {
-            READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
-            READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+            packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A);
+            packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
             __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
             __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
@@ -1387,8 +2077,8 @@ class tinyBLAS_PPC {
         __builtin_mma_xxsetaccz(&acc_2);
         __builtin_mma_xxsetaccz(&acc_3);
         for (int l = 0; l < k; l+=8) {
-            READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
-            READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
+            packTranspose<vector float>(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A);
+            packTranspose<vector float>(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B);
             for(int x = 0; x < 16; x+=2) {
                 __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
                 __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
@@ -1571,15 +2261,15 @@ class tinyBLAS_PPC {
             vec_t vec_A[4], vec_B[4];
             for (int l=0; l<k; l+=4) {
                 if (RN >= 4 && RM == 1) {
-                    float* a = const_cast<float*>(A+(ii)*lda+l);
-                    READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+                    TA* a = const_cast<TA*>(A+(ii)*lda+l);
+                    packTranspose<vector float>(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B);
                     vec_A[0] = (vec_t)vec_xl(0,a);
-                    vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
-                    vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
-                    vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
+                    vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1));
+                    vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2));
+                    vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3));
                 } else {
-                    READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
-                    READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
+                    packTranspose<vector float>(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A);
+                    packTranspose<vector float>(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B);
                 }
                 __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
                 __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
@@ -1589,7 +2279,7 @@ class tinyBLAS_PPC {
             __builtin_mma_disassemble_acc(vec_C, &acc_0);
             for (int I = 0; I < RM; I++) {
                 for (int J = 0; J < RN; J++) {
-                    *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
+                    *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J);
                 }
             }
        }
@@ -1812,6 +2502,20 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
             params->ith, params->nth};
         tb.matmul(m, n);
         return true;
+
+#elif defined(__MMA__)
+        if (n < 8 && n != 4)
+           return false;
+        if (m < 8 && m != 4)
+           return false;
+        tinyBLAS_Q0_PPC<block_q8_0, block_q8_0, float> tb{
+            k, (const block_q8_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            params->ith, params->nth};
+        tb.matmul(m, n);
+        return true;
+
 #else
         return false;
 #endif
diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu
index 2f42b8a9538..aafbaf803b4 100644
--- a/ggml/src/ggml-cuda/concat.cu
+++ b/ggml/src/ggml-cuda/concat.cu
@@ -124,7 +124,7 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE)
           uint64_t   nb1,
           uint64_t   nb2,
           uint64_t   nb3){
-    static_assert(dim >= 0 && dim <= 3);
+    static_assert(dim >= 0 && dim <= 3, "dim must be in [0, 3]");
 
     const int64_t i3 = blockIdx.z;
     const int64_t i2 = blockIdx.y;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index 3896f956d73..5b0dfacefc9 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -680,6 +680,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cuda<half>;
+        case GGML_TYPE_BF16:
+            return convert_unary_cuda<nv_bfloat16>;
         default:
             return nullptr;
     }
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index c180adc84b8..1dac397c4b0 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -37,6 +37,7 @@
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv6.cuh"
+#include "ggml-cuda/gla.cuh"
 
 #include <algorithm>
 #include <array>
@@ -1728,7 +1729,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
 
-    bool use_mul_mat_vec   = src0->type == GGML_TYPE_F16
+    bool use_mul_mat_vec   = (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
@@ -2167,6 +2168,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_RWKV_WKV6:
             ggml_cuda_op_rwkv_wkv6(ctx, dst);
             break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            ggml_cuda_op_gated_linear_attn(ctx, dst);
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
             break;
@@ -2285,6 +2289,66 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
 }
 
 #ifdef USE_CUDA_GRAPH
+static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
+    std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool use_cuda_graph) {
+
+    // Loop over nodes in GGML graph to obtain info needed for CUDA graph
+    cuda_ctx->cuda_graph->updated_kernel_arg.clear();
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        ggml_tensor * node = cgraph->nodes[i];
+
+        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+            continue;
+        }
+
+        if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
+            use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
+#endif
+        }
+
+        if (node->op == GGML_OP_MUL_MAT_ID) {
+            use_cuda_graph = false; // This node type is not supported by CUDA graph capture
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
+#endif
+        }
+
+        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
+            // disable CUDA graphs for batch size > 1 for now.
+            // Changes in batch size or context size can cause changes to the grid size of some kernels.
+            use_cuda_graph = false;
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+        }
+
+        if (node->op == GGML_OP_CPY) {
+            // store the copy op parameter which changes with each token.
+            cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
+            // store a pointer to each copy op CUDA kernel to identify it later
+            void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+            if (!ptr) {
+                use_cuda_graph = false;
+#ifndef NDEBUG
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
+#endif
+            } else {
+                if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
+                    ggml_cuda_cpy_fn_ptrs.push_back(ptr);
+                }
+            }
+        }
+
+        if (!use_cuda_graph) {
+            break;
+        }
+    }
+
+    return use_cuda_graph;
+}
+
 static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
     graph_node_properties->node_address = node->data;
     graph_node_properties->node_op = node->op;
@@ -2335,149 +2399,105 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
 
     return true;
 }
-#endif
 
-static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
-    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vector<void *> & ggml_cuda_cpy_fn_ptrs, bool cuda_graph_update_required) {
 
-    ggml_cuda_set_device(cuda_ctx->device);
+    if (cuda_graph_update_required) {
+        // Extract nodes from graph
+        // First call with null argument gets number of nodes in graph
+        CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
+        // Subsequent call with non-null argument gets nodes
+        cuda_ctx->cuda_graph->nodes.clear();
+        cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
+        cuda_ctx->cuda_graph->params.clear();
+        cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
+        if (cuda_ctx->cuda_graph->num_nodes > 0) {
+            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
 
-#ifdef USE_CUDA_GRAPH
-    static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
-
-    // Objects required for CUDA Graph
-    if (cuda_ctx->cuda_graph == nullptr) {
-        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
+            // Loop over nodes, and extract kernel parameters from each node
+            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+                cudaGraphNodeType node_type;
+                CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
+                if (node_type == cudaGraphNodeTypeKernel) {
+                    cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
+                    if (stat == cudaErrorInvalidDeviceFunction) {
+                        // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
+                        // We don't need to update blas nodes, so clear error and move on.
+                        cudaGetLastError();
+                    } else {
+                        GGML_ASSERT(stat == cudaSuccess);
+                    }
+                }
+            }
+        }
+    } else {
+        // One of the arguments to the copy kernel is updated for each token, hence we need to
+        // replace that argument with the updated value in the CUDA graph
+        // on update steps, the live parameters will already be captured
+        int k = 0;
+        for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
+            if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
+                char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
+                cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
+                CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
+            }
+        }
     }
+}
+
+static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
 
-    bool use_cuda_graph = true;
     bool cuda_graph_update_required = false;
-    // vector of pointers to CUDA cpy kernels, which are required to identify
-    // kernel parameters which need updated in the graph for each token
-    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
 
-    if (cuda_ctx->cuda_graph->graph == nullptr) {
-        if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
-            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
-#ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
-#endif
-        }
+    if (cuda_ctx->cuda_graph->instance == nullptr) {
+        cuda_graph_update_required = true;
     }
 
-    // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
-    // or previous graph capture failure.
-    // Also disable for multi-gpu for now. TO DO investigate
-    if (disable_cuda_graphs_due_to_env
-        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
-        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
-        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
-        use_cuda_graph = false;
+    // Check if the graph size has changed
+    if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
+        cuda_graph_update_required = true;
+        cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
     }
 
-    if (use_cuda_graph) {
-        if (cuda_ctx->cuda_graph->instance == nullptr) {
-            cuda_graph_update_required = true;
+    // Loop over nodes in GGML graph to determine if CUDA graph update is required
+    // and store properties to allow this comparison for the next token
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        bool has_matching_properties = true;
+        if (!cuda_graph_update_required) {
+            has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
         }
-
-        // Check if the graph size has changed
-        if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
+        if (!has_matching_properties) {
             cuda_graph_update_required = true;
-            cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
-        }
-
-        // Loop over nodes in GGML graph to determine if CUDA graph update is required
-        // and store properties to allow this comparison for the next token
-        for (int i = 0; i < cgraph->n_nodes; i++) {
-            bool has_matching_properties = true;
-            if (!cuda_graph_update_required) {
-                has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
-            }
-            if (!has_matching_properties) {
-                cuda_graph_update_required = true;
-            }
-            set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
         }
+        set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
+    }
 
-        // Loop over nodes in GGML graph to obtain info needed for CUDA graph
-        cuda_ctx->cuda_graph->updated_kernel_arg.clear();
-        for (int i = 0; i < cgraph->n_nodes; i++) {
-            ggml_tensor * node = cgraph->nodes[i];
-
-            if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
-                continue;
-            }
-
-            if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
-                use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
-#endif
-            }
-
-            if (node->op == GGML_OP_MUL_MAT_ID) {
-                use_cuda_graph = false; // This node type is not supported by CUDA graph capture
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
-#endif
-            }
+    return cuda_graph_update_required;
+}
 
-            if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
-                // disable CUDA graphs for batch size > 1 for now.
-                // Changes in batch size or context size can cause changes to the grid size of some kernels.
-                use_cuda_graph = false;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
-#endif
-            }
+static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 
-            if (node->op == GGML_OP_CPY) {
-                // store the copy op parameter which changes with each token.
-                cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
-                // store a pointer to each copy op CUDA kernel to identify it later
-                void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
-                if (!ptr) {
-                    use_cuda_graph = false;
+    cudaGraphExecUpdateResultInfo result_info;
+    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
+    if (stat == cudaErrorGraphExecUpdateFailure) {
 #ifndef NDEBUG
-                    GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
+        GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
 #endif
-                } else {
-                    if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
-                        ggml_cuda_cpy_fn_ptrs.push_back(ptr);
-                    }
-                }
-            }
-
-            if (!use_cuda_graph) {
-                break;
-            }
-        }
-
-        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
-        if (use_cuda_graph && cuda_graph_update_required) {
-            cuda_ctx->cuda_graph->number_consecutive_updates++;
-        } else {
-            cuda_ctx->cuda_graph->number_consecutive_updates = 0;
-        }
-
-        if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
-            cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
-#ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
-#endif
-        }
-    }
-
-    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
-        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
+        // The pre-existing graph exec cannot be updated due to violated constraints
+        // so instead clear error and re-instantiate
+        cudaGetLastError();
+        CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
+        cuda_ctx->cuda_graph->instance = nullptr;
+        CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+    } else {
+        GGML_ASSERT(stat == cudaSuccess);
     }
+}
+#endif
 
-#else
-    bool use_cuda_graph = false;
-    bool cuda_graph_update_required = false;
-#endif // USE_CUDA_GRAPH
-
-    bool graph_evaluated_or_captured = false;
+static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
+   [[maybe_unused]] std::vector<void *> & ggml_cuda_cpy_fn_ptrs,  bool & graph_evaluated_or_captured, bool & use_cuda_graph,
+    bool & cuda_graph_update_required) {
 
     while (!graph_evaluated_or_captured) {
         // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
@@ -2515,19 +2535,8 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
                 CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
                 cuda_ctx->cuda_graph->graph = nullptr;
             }
-            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
 
-#if 0
-            if (disable_cuda_graphs_due_to_failed_capture) {
-                use_cuda_graph = false;
-                cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
-#endif
-            } else {
-                graph_evaluated_or_captured = true; // CUDA graph has been captured
-            }
-#endif
+            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
             graph_evaluated_or_captured = true; // CUDA graph has been captured
         } else {
             graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
@@ -2540,72 +2549,91 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         }
 
         // Perform update to graph (if required for this token), and change copy parameter (required for every token)
+        maintain_cuda_graph(cuda_ctx, ggml_cuda_cpy_fn_ptrs, cuda_graph_update_required);
 
-        if (cuda_graph_update_required) {
-            // Extract nodes from graph
-            // First call with null argument gets number of nodes in graph
-            CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
-            // Subsequent call with non-null argument gets nodes
-            cuda_ctx->cuda_graph->nodes.clear();
-            cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
-            cuda_ctx->cuda_graph->params.clear();
-            cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);
-            if (cuda_ctx->cuda_graph->num_nodes > 0) {
-                CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, cuda_ctx->cuda_graph->nodes.data(), &cuda_ctx->cuda_graph->num_nodes));
-
-                // Loop over nodes, and extract kernel parameters from each node
-                for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                    cudaGraphNodeType node_type;
-                    CUDA_CHECK(cudaGraphNodeGetType(cuda_ctx->cuda_graph->nodes[i], &node_type));
-                    if (node_type == cudaGraphNodeTypeKernel) {
-                        cudaError_t stat = cudaGraphKernelNodeGetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]); // Get params using runtime
-                        if (stat == cudaErrorInvalidDeviceFunction) {
-                            // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
-                            // We don't need to update blas nodes, so clear error and move on.
-                            cudaGetLastError();
-                        } else {
-                            GGML_ASSERT(stat == cudaSuccess);
-                        }
-                    }
-                }
-            }
+        // Update graph executable
+        update_cuda_graph_executable(cuda_ctx);
+
+        // Launch graph
+        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
+#else
+        graph_evaluated_or_captured = true;
+#endif  // USE_CUDA_GRAPH
+    }
+}
+
+static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+    ggml_cuda_set_device(cuda_ctx->device);
+
+    // vector of pointers to CUDA cpy kernels, which are required to identify
+    // kernel parameters which need updated in the graph for each token
+    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
+
+#ifdef USE_CUDA_GRAPH
+    static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
+
+    // Objects required for CUDA Graph
+    if (cuda_ctx->cuda_graph == nullptr) {
+        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
+    }
+
+    bool use_cuda_graph = true;
+    bool cuda_graph_update_required = false;
+
+    if (cuda_ctx->cuda_graph->graph == nullptr) {
+        if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
+            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
+#ifndef NDEBUG
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+#endif
         }
+    }
 
-        // One of the arguments to the copy kernel is updated for each token, hence we need to
-        // replace that argument with the updated value in the CUDA graph
-        if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
-            int k = 0;
-            for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
-                    char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
-                    cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
-                    CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
-                }
-            }
+    // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
+    // or previous graph capture failure.
+    // Also disable for multi-gpu for now. TO DO investigate
+    if (disable_cuda_graphs_due_to_env
+        || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
+        || cuda_ctx->cuda_graph->disable_due_to_too_many_updates
+        || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
+        use_cuda_graph = false;
+    }
+
+    if (use_cuda_graph) {
+        cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
+
+        use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph,
+                             ggml_cuda_cpy_fn_ptrs, use_cuda_graph);
+
+        // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
+        if (use_cuda_graph && cuda_graph_update_required) {
+            cuda_ctx->cuda_graph->number_consecutive_updates++;
+        } else {
+            cuda_ctx->cuda_graph->number_consecutive_updates = 0;
         }
 
-        // Update graph executable
-        cudaGraphExecUpdateResultInfo result_info;
-        cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
-        if (stat == cudaErrorGraphExecUpdateFailure) {
+        if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
+            cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
 #ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
 #endif
-            // The pre-existing graph exec cannot be updated due to violated constraints
-            // so instead clear error and re-instantiate
-            cudaGetLastError();
-            CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
-            cuda_ctx->cuda_graph->instance = nullptr;
-            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
-        } else {
-            GGML_ASSERT(stat == cudaSuccess);
         }
-        // Launch graph
-        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
+    }
+
+    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
+    }
+
 #else
-        graph_evaluated_or_captured = true;
+    bool use_cuda_graph = false;
+    bool cuda_graph_update_required = false;
 #endif // USE_CUDA_GRAPH
-    }
+
+    bool graph_evaluated_or_captured = false;
+
+    evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, ggml_cuda_cpy_fn_ptrs, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
 
     return GGML_STATUS_SUCCESS;
 }
@@ -2869,6 +2897,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     case GGML_TYPE_IQ3_XXS:
                     case GGML_TYPE_IQ4_NL:
                     case GGML_TYPE_IQ4_XS:
+                    case GGML_TYPE_BF16:
 #ifdef GGML_USE_MUSA
                         if (a->type == GGML_TYPE_Q3_K) {
                             return false;
@@ -3010,6 +3039,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
diff --git a/ggml/src/ggml-cuda/gla.cu b/ggml/src/ggml-cuda/gla.cu
new file mode 100644
index 00000000000..f7d615a8282
--- /dev/null
+++ b/ggml/src/ggml-cuda/gla.cu
@@ -0,0 +1,93 @@
+#include "common.cuh"
+#include "gla.cuh"
+
+template <int HEAD_SIZE>
+static __global__ void gated_linear_attn_f32(const int B, const int T, const int C, const int H, const float scale,
+     const float * k, const float * v, const float * r, const float * td, const float * s, float * dst) {
+    const int tid = threadIdx.x;
+    const int bid = blockIdx.x;
+
+    const int head_size = HEAD_SIZE;
+    const int batch_i = bid / H;
+    const int head_i = bid % H;
+    const int state_size = C * head_size;
+    const int n_seq_tokens = T / B;
+
+    float state[head_size];
+    __shared__ float _k[head_size], _r[head_size], _td[head_size];
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+    }
+
+    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+        __syncthreads();
+        _k[tid] = k[t];
+        _r[tid] = r[t];
+        _td[tid] = td[t];
+        __syncthreads();
+
+        const float _v = v[t];
+        float y = 0;
+        for (int j = 0; j < head_size; j += 4) {
+            const float4 & k = (float4 &)(_k[j]);
+            const float4 & r = (float4 &)(_r[j]);
+            const float4 & td = (float4 &)(_td[j]);
+            float4 & s = (float4 &)(state[j]);
+            float4 kv;
+
+            kv.x = k.x * _v;
+            kv.y = k.y * _v;
+            kv.z = k.z * _v;
+            kv.w = k.w * _v;
+
+            s.x = s.x * td.x + kv.x;
+            s.y = s.y * td.y + kv.y;
+            s.z = s.z * td.z + kv.z;
+            s.w = s.w * td.w + kv.w;
+
+            y += r.x * s.x;
+            y += r.y * s.y;
+            y += r.z * s.z;
+            y += r.w * s.w;
+        }
+        dst[t] = y * scale;
+    }
+
+    #pragma unroll
+    for (int i = 0; i < head_size; i++) {
+        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+    }
+}
+
+void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const float * k_d  = (const float *)dst->src[0]->data;
+    const float * v_d  = (const float *)dst->src[1]->data;
+    const float * r_d  = (const float *)dst->src[2]->data;
+    const float * td_d = (const float *)dst->src[3]->data;
+    const float * s_d  = (const float *)dst->src[4]->data;
+
+    const int64_t B = dst->src[4]->ne[1];
+    const int64_t T = dst->src[0]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t H = dst->src[0]->ne[1];
+
+    float scale;
+    memcpy(&scale, (float*)dst->op_params, sizeof(float));
+
+    float * dst_d = (float *)dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
+    GGML_ASSERT(C % H == 0);
+    GGML_ASSERT(C / H == 64 || C / H == 128);
+
+
+    if (C / H == 64) {
+        gated_linear_attn_f32<64><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+    } else {
+        gated_linear_attn_f32<128><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+    }
+}
diff --git a/ggml/src/ggml-cuda/gla.cuh b/ggml/src/ggml-cuda/gla.cuh
new file mode 100644
index 00000000000..2c82ad7dd72
--- /dev/null
+++ b/ggml/src/ggml-cuda/gla.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu
index a4b4f6bc10d..ac45f2d17f1 100644
--- a/ggml/src/ggml-cuda/mmv.cu
+++ b/ggml/src/ggml-cuda/mmv.cu
@@ -1,9 +1,9 @@
 #include "common.cuh"
 #include "mmv.cuh"
 
-template <typename type_acc, int block_size>
+template <typename T, typename type_acc, int block_size>
 static __global__ void mul_mat_vec(
-        const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
+        const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
         const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
     const int64_t row     = blockIdx.x;
     const int64_t channel = blockIdx.z;
@@ -13,7 +13,6 @@ static __global__ void mul_mat_vec(
     y   +=  channel               *stride_channel_y;
     dst +=  channel               *stride_channel_dst;
 
-    const half2  * x2 = (const half2  *) x;
     const float2 * y2 = (const float2 *) y;
 
     extern __shared__ char data_mmv[];
@@ -28,28 +27,44 @@ static __global__ void mul_mat_vec(
 
     float sumf;
 
-    if (std::is_same<type_acc, float>::value) {
-        sumf = 0.0f;
+    if constexpr (std::is_same<T, half>::value) {
+        const half2 * x2 = (const half2 *) x;
 
-        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
-            const float2 tmpx = __half22float2(x2[col2]);
-            const float2 tmpy = y2[col2];
-            sumf += tmpx.x * tmpy.x;
-            sumf += tmpx.y * tmpy.y;
-        }
-    } else {
+        if (std::is_same<type_acc, float>::value) {
+            sumf = 0.0f;
+
+            for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+                const float2 tmpx = __half22float2(x2[col2]);
+                const float2 tmpy = y2[col2];
+                sumf += tmpx.x * tmpy.x;
+                sumf += tmpx.y * tmpy.y;
+            }
+        } else {
 #ifdef FP16_AVAILABLE
-        half2 sumh2 = make_half2(0.0f, 0.0f);
+            half2 sumh2 = make_half2(0.0f, 0.0f);
 
-        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
-            const float2 tmp = y2[col2];
-            sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
-        }
+            for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+                const float2 tmp = y2[col2];
+                sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
+            }
 
-        sumf = __low2float(sumh2) + __high2float(sumh2);
+            sumf = __low2float(sumh2) + __high2float(sumh2);
 #else
-        NO_DEVICE_CODE;
+            NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
+        }
+    } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
+        const int * x2 = (const int *) x;
+        sumf = 0.0f;
+
+        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+            const int    tmpx = x2[col2];
+            const float2 tmpy = y2[col2];
+            sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
+            sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+        }
+    } else {
+        static_assert(std::is_same<T, void>::value, "unsupported type");
     }
 
     sumf = warp_reduce_sum(sumf);
@@ -71,9 +86,9 @@ static __global__ void mul_mat_vec(
     dst[row] = sumf;
 }
 
-template <typename type_acc>
+template <typename T, typename type_acc>
 static void launch_mul_mat_vec_cuda(
-        const half * x, const float * y, float * dst,
+        const T * x, const float * y, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
         cudaStream_t stream) {
@@ -97,35 +112,35 @@ static void launch_mul_mat_vec_cuda(
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case   32: {
-            mul_mat_vec<<>>
+            mul_mat_vec<<>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case   64: {
-            mul_mat_vec<<>>
+            mul_mat_vec<<>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case   96: {
-            mul_mat_vec<<>>
+            mul_mat_vec<<>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case  128: {
-            mul_mat_vec<<>>
+            mul_mat_vec<<>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case  160: {
-            mul_mat_vec<<>>
+            mul_mat_vec<<>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case  192: {
-            mul_mat_vec<<>>
+            mul_mat_vec<<>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case  224: {
-            mul_mat_vec<<>>
+            mul_mat_vec<<>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         case  256: {
-            mul_mat_vec<<>>
+            mul_mat_vec<<>>
                 (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
         } break;
         default: {
@@ -134,25 +149,25 @@ static void launch_mul_mat_vec_cuda(
     }
 }
 
+template <typename T>
 static void mul_mat_vec_cuda(
-        const half * x, const float * y, float * dst,
+        const T * x, const float * y, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
         enum ggml_prec prec, cudaStream_t stream) {
     switch (prec) {
         case GGML_PREC_DEFAULT: {
-            launch_mul_mat_vec_cuda<half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
+            launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
                 stride_channel_x, stride_channel_y, stride_channel_dst, stream);
         } break;
         case GGML_PREC_F32: {
-            launch_mul_mat_vec_cuda<float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
+            launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
                 stride_channel_x, stride_channel_y, stride_channel_dst, stream);
         } break;
     }
 }
 
 void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type  == GGML_TYPE_F32);
 
@@ -164,7 +179,6 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
 
-    const half  * src0_d = (const half  *) src0->data;
     const float * src1_d = (const float *) src1->data;
     float       *  dst_d = (float       *)  dst->data;
 
@@ -181,7 +195,20 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const int64_t channel_stride_y   = src1->nb[2] / ggml_type_size(src1->type);
     const int64_t channel_stride_dst =  dst->nb[2] / ggml_type_size( dst->type);
 
-    mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+    switch (src0->type) {
+        case GGML_TYPE_F16: {
+            const half * src0_d = (const half *) src0->data;
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
+                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+        } break;
+        case GGML_TYPE_BF16: {
+            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
+                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+        } break;
+        default:
+            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+    }
 }
 
 void ggml_cuda_op_mul_mat_vec(
@@ -190,7 +217,6 @@ void ggml_cuda_op_mul_mat_vec(
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream) {
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type  == GGML_TYPE_F32);
 
@@ -211,8 +237,20 @@ void ggml_cuda_op_mul_mat_vec(
     const int64_t channel_stride_y   = 0;
     const int64_t channel_stride_dst = 0;
 
-    mul_mat_vec_cuda((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
-        nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+    switch (src0->type) {
+        case GGML_TYPE_F16: {
+            const half * src0_d = (const half *) src0_dd_i;
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
+                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+        } break;
+        case GGML_TYPE_BF16: {
+            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
+                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+        } break;
+        default:
+            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+    }
 
     GGML_UNUSED(ctx);
     GGML_UNUSED(src1);
diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h
index db9f6a165d0..1746b073203 100644
--- a/ggml/src/ggml-cuda/vendors/cuda.h
+++ b/ggml/src/ggml-cuda/vendors/cuda.h
@@ -3,6 +3,7 @@
 #include <cuda_runtime.h>
 #include <cuda.h>
 #include <cublas_v2.h>
+#include <cuda_bf16.h>
 #include <cuda_fp16.h>
 
 #if CUDART_VERSION < 11020
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index 3205534d66f..c905b15d76c 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -3,6 +3,7 @@
 #include <hip/hip_runtime.h>
 #include <hipblas/hipblas.h>
 #include <hip/hip_fp16.h>
+#include <hip/hip_bfloat16.h>
 #ifdef __HIP_PLATFORM_AMD__
 // for rocblas_initialize()
 #include "rocblas/rocblas.h"
@@ -121,6 +122,8 @@
     #define __has_builtin(x) 0
 #endif
 
+typedef hip_bfloat16 nv_bfloat16;
+
 typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
 typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
 static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h
index 1604b8229d5..6cc1b69ee33 100644
--- a/ggml/src/ggml-cuda/vendors/musa.h
+++ b/ggml/src/ggml-cuda/vendors/musa.h
@@ -3,6 +3,7 @@
 #include <musa_runtime.h>
 #include <musa.h>
 #include <mublas.h>
+#include <musa_bf16.h>
 #include <musa_fp16.h>
 #define CUBLAS_COMPUTE_16F CUDA_R_16F
 #define CUBLAS_COMPUTE_32F CUDA_R_32F
@@ -132,3 +133,5 @@
 #define cudaKernelNodeParams musaKernelNodeParams
 #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
 #define cudaStreamEndCapture musaStreamEndCapture
+
+typedef mt_bfloat16 nv_bfloat16;
diff --git a/ggml/src/ggml-cuda/wkv6.cu b/ggml/src/ggml-cuda/wkv6.cu
index 42578341a38..bbdafbee581 100644
--- a/ggml/src/ggml-cuda/wkv6.cu
+++ b/ggml/src/ggml-cuda/wkv6.cu
@@ -73,9 +73,9 @@ void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     const float * s_d  = (const float *)dst->src[5]->data;
 
     const int64_t B = dst->src[5]->ne[1];
-    const int64_t T = dst->src[0]->ne[3];
+    const int64_t T = dst->src[0]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[2];
+    const int64_t H = dst->src[0]->ne[1];
 
     float * dst_d = (float *)dst->data;
 
diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt
index b15fbd24d6b..d090ba9bd98 100644
--- a/ggml/src/ggml-hip/CMakeLists.txt
+++ b/ggml/src/ggml-hip/CMakeLists.txt
@@ -70,7 +70,9 @@ ggml_add_backend_library(ggml-hip
                         )
 
 # TODO: do not use CUDA definitions for HIP
-target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
+if (NOT GGML_BACKEND_DL)
+    target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
+endif()
 
 add_compile_definitions(GGML_USE_HIP)
 
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index 549772c57c9..eab017889c9 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -3,6 +3,8 @@
 // GGML internal header
 
 #include "ggml.h"
+#include "gguf.h"
+
 #include <assert.h>
 #include <math.h>
 #include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
@@ -551,22 +553,15 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
 #define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
 #define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
 
-// expose GGUF internals for test code
-
-GGML_API size_t gguf_type_size(enum gguf_type type);
-
-GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
-
-struct gguf_buf {
-    void * data;
-    size_t size;
-    size_t offset;
-};
-GGML_API struct gguf_buf gguf_buf_init(size_t size);
-GGML_API void gguf_buf_free(struct gguf_buf buf);
-
-GGML_API void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta);
-
 #ifdef __cplusplus
 }
 #endif
+
+#ifdef __cplusplus
+#include <vector>
+
+// expose GGUF internals for test code
+GGML_API size_t gguf_type_size(enum gguf_type type);
+GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
+GGML_API void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & buf, bool only_meta);
+#endif // __cplusplus
diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt
new file mode 100644
index 00000000000..45328a65793
--- /dev/null
+++ b/ggml/src/ggml-opencl/CMakeLists.txt
@@ -0,0 +1,147 @@
+find_package(OpenCL REQUIRED)
+find_package(Python3 REQUIRED)
+
+set(TARGET_NAME ggml-opencl)
+
+ggml_add_backend_library(${TARGET_NAME}
+                         ggml-opencl.cpp
+                         ../../include/ggml-opencl.h)
+target_link_libraries(${TARGET_NAME} PRIVATE ${OpenCL_LIBRARIES})
+target_include_directories(${TARGET_NAME} PRIVATE ${OpenCL_INCLUDE_DIRS})
+
+if (GGML_OPENCL_PROFILING)
+    message(STATUS "OpenCL profiling enabled (increases CPU overhead)")
+    add_compile_definitions(GGML_OPENCL_PROFILING)
+endif ()
+
+add_compile_definitions(GGML_OPENCL_SOA_Q)
+
+if (GGML_OPENCL_USE_ADRENO_KERNELS)
+    message(STATUS "OpenCL will use matmul kernels optimized for Adreno")
+    add_compile_definitions(GGML_OPENCL_USE_ADRENO_KERNELS)
+endif ()
+
+if (GGML_OPENCL_EMBED_KERNELS)
+    add_compile_definitions(GGML_OPENCL_EMBED_KERNELS)
+
+    set(OPENCL_CL_SOURCE_EMBED         "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl.cl.h")
+    set(OPENCL_MM_CL_SOURCE_EMBED      "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_mm.cl.h")
+    set(OPENCL_CVT_CL_SOURCE_EMBED     "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_cvt.cl.h")
+
+    set(OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED             "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_gemv_noshuffle.cl.h")
+    set(OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED     "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_gemv_noshuffle_general.cl.h")
+    set(OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED          "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_mul_mat_Ab_Bi_8x4.cl.h")
+    set(OPENCL_TRANSPOSE_16_SOURCE_EMBED               "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_transpose_16.cl.h")
+    set(OPENCL_TRANSPOSE_32_SOURCE_EMBED               "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_transpose_32.cl.h")
+    set(OPENCL_TRANSPOSE_32_16_SOURCE_EMBED            "${CMAKE_BINARY_DIR}/autogenerated/ggml-opencl_transpose_32_16.cl.h")
+
+    set(EMBED_KERNEL_SCRIPT             "${CMAKE_CURRENT_SOURCE_DIR}/kernels/embed_kernel.py")
+    file(MAKE_DIRECTORY                 "${CMAKE_BINARY_DIR}/autogenerated")
+
+    include_directories("${CMAKE_BINARY_DIR}/autogenerated")
+
+    # Python must be accessible from command line
+    add_custom_command(
+        OUTPUT ${OPENCL_CL_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+            ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl.cl
+            ${OPENCL_CL_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_MM_CL_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+            ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_mm.cl
+            ${OPENCL_MM_CL_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_mm.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_mm.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_CVT_CL_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+            ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_cvt.cl
+            ${OPENCL_CVT_CL_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_cvt.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_cvt.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+            ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_gemv_noshuffle.cl
+            ${OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_gemv_noshuffle.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_gemv_noshuffle.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+            ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_gemv_noshuffle_general.cl
+            ${OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_gemv_noshuffle_general.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_gemv_noshuffle_general.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+            ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl
+            ${OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_mul_mat_Ab_Bi_8x4.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_TRANSPOSE_16_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+            ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_transpose_16.cl
+            ${OPENCL_TRANSPOSE_16_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_transpose_16.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_transpose_16.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_TRANSPOSE_32_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+            ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_transpose_32.cl
+            ${OPENCL_TRANSPOSE_32_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_transpose_32.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_transpose_32.cl.h"
+    )
+
+    add_custom_command(
+        OUTPUT ${OPENCL_TRANSPOSE_32_16_SOURCE_EMBED}
+        COMMAND ${Python3_EXECUTABLE} ${EMBED_KERNEL_SCRIPT}
+            ${CMAKE_CURRENT_SOURCE_DIR}/kernels/ggml-opencl_transpose_32_16.cl
+            ${OPENCL_TRANSPOSE_32_16_SOURCE_EMBED}
+        DEPENDS kernels/ggml-opencl_transpose_32_16.cl ${EMBED_KERNEL_SCRIPT}
+        COMMENT "Generate ggml-opencl_transpose_32_16.cl.h"
+    )
+
+    target_sources(${TARGET_NAME} PRIVATE
+                   ${OPENCL_CL_SOURCE_EMBED}
+                   ${OPENCL_MM_CL_SOURCE_EMBED}
+                   ${OPENCL_CVT_CL_SOURCE_EMBED}
+                   ${OPENCL_GEMV_NOSHUFFLE_SOURCE_EMBED}
+                   ${OPENCL_GEMV_NOSHUFFLE_GENERAL_SOURCE_EMBED}
+                   ${OPENCL_MUL_MAT_Ab_Bi_8x4_SOURCE_EMBED}
+                   ${OPENCL_TRANSPOSE_16_SOURCE_EMBED}
+                   ${OPENCL_TRANSPOSE_32_SOURCE_EMBED}
+                   ${OPENCL_TRANSPOSE_32_16_SOURCE_EMBED})
+else ()
+    # copy ggml-opencl.cl to bin directory
+    configure_file(kernels/ggml-opencl.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_mm.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_mm.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_cvt.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_cvt.cl COPYONLY)
+
+    configure_file(kernels/ggml-opencl_gemv_noshuffle.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_gemv_noshuffle.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_gemv_noshuffle_general.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_gemv_noshuffle_general.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_mul_mat_Ab_Bi_8x4.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_transpose_16.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_transpose_16.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_transpose_32.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_transpose_32.cl COPYONLY)
+    configure_file(kernels/ggml-opencl_transpose_32_16.cl ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-opencl_transpose_32_16.cl COPYONLY)
+endif ()
diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
new file mode 100644
index 00000000000..ed90e471ac0
--- /dev/null
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -0,0 +1,4004 @@
+#define CL_TARGET_OPENCL_VERSION 220
+#define CL_USE_DEPRECATED_OPENCL_1_2_APIS
+
+// suppress warnings in CL headers for GCC and Clang
+#pragma GCC diagnostic ignored "-Woverlength-strings"
+#ifdef __clang__
+#pragma GCC diagnostic ignored "-Wgnu-anonymous-struct"
+#endif
+
+#include "ggml-opencl.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+#include "ggml-backend-impl.h"
+#include "ggml.h"
+
+#include <CL/cl.h>
+
+#include <string.h>
+
+#include <cstddef>
+#include <cstdint>
+#include <atomic>
+#include <fstream>
+#include <limits>
+#include <vector>
+#include <string>
+#include <cmath>
+
+#undef MIN
+#undef MAX
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
+#define UNUSED(x) (void)(x)
+
+#define CL_CHECK(err)                                               \
+    do {                                                            \
+        cl_int err_ = (err);                                        \
+        if (err_ != CL_SUCCESS) {                                   \
+            GGML_LOG_ERROR("ggml_opencl: %s error %d at %s:%d\n",  \
+                #err, err_, __FILE__, __LINE__);                    \
+            GGML_ASSERT(0);                                         \
+        }                                                           \
+    } while (0)
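+
+// Usage sketch (hypothetical call site): any OpenCL call that returns a cl_int can be
+// wrapped, e.g.
+//     CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf));
+// On failure the macro logs the stringified expression, the error code, and the source
+// location, then asserts.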
+
+//------------------------------------------------------------------------------
+// OpenCL
+//------------------------------------------------------------------------------
+
+bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor);
+
+enum GPU_FAMILY {
+    ADRENO,
+    INTEL,
+    UNKNOWN,
+};
+
+enum ADRENO_GPU_GEN {
+    ADRENO_UNKNOWN,
+    A7X,
+    A8X,
+    X1E,
+};
+
+static ADRENO_GPU_GEN get_adreno_gpu_gen(const char *device_name) {
+    if (strstr(device_name, "730") ||
+        strstr(device_name, "740") ||
+        strstr(device_name, "750")) {
+        return ADRENO_GPU_GEN::A7X;
+    }
+
+    if (strstr(device_name, "830")) {
+        return ADRENO_GPU_GEN::A8X;
+    }
+
+    if (strstr(device_name, "X1")) {
+        return ADRENO_GPU_GEN::X1E;
+    }
+
+    return ADRENO_GPU_GEN::ADRENO_UNKNOWN;
+}
+
+static int get_adreno_cl_compiler_version(const char *driver_version) {
+    std::string driver_ver_str(driver_version);
+    size_t compiler_ver_pos = driver_ver_str.find("E031");
+    size_t compiler_ver_len = 13;
+    size_t compiler_ver_offset = 5;
+
+    if (compiler_ver_pos == std::string::npos) {
+        compiler_ver_pos = driver_ver_str.find("DX");
+        if (compiler_ver_pos == std::string::npos) {
+            return -1;
+        }
+        compiler_ver_len = 11;
+        compiler_ver_offset = 3;
+    }
+
+    std::string compiler_ver_str = driver_ver_str.substr(compiler_ver_pos, compiler_ver_len);
+    std::string major_ver_str = compiler_ver_str.substr(compiler_ver_offset, 2);
+    return std::atoi(major_ver_str.c_str());
+}
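+
+// Worked example (assuming an Adreno driver string of the form "...E031.38.01.00..."):
+// the 13-character window starts at "E031" and characters 5..6 are "38", so the function
+// returns 38; for "DX"-style strings the window is 11 characters with offset 3.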
+
+// backend device context
+struct ggml_backend_opencl_device_context {
+    cl_platform_id platform;
+    std::string platform_name;
+
+    cl_device_id device;
+    std::string device_name;
+};
+
+// backend context
+struct ggml_backend_opencl_context {
+    cl_device_id device;
+    std::string device_name;
+
+    std::string driver_version;
+
+    GPU_FAMILY gpu_family;
+    ADRENO_GPU_GEN adreno_gen;
+
+    cl_int alignment;
+    size_t max_alloc_size;
+    bool fp16_support;
+
+    int adreno_wave_size;
+
+    cl_context context;
+    cl_command_queue queue;
+
+    cl_program program;
+    cl_program program_1;
+    cl_program program_2;
+
+    cl_kernel kernel_add, kernel_add_row;
+    cl_kernel kernel_mul, kernel_mul_row;
+    cl_kernel kernel_scale;
+    cl_kernel kernel_silu, kernel_silu_4;
+    cl_kernel kernel_gelu, kernel_gelu_4;
+    cl_kernel kernel_relu;
+    cl_kernel kernel_clamp;
+    cl_kernel kernel_norm;
+    cl_kernel kernel_rms_norm;
+    cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
+    cl_kernel kernel_soft_max, kernel_soft_max_4;
+    cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
+    cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
+    cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32;
+    cl_kernel kernel_mul_mat_f32_f32;
+    cl_kernel kernel_mul_mat_f16_f16;
+    cl_kernel kernel_mul_mat_f16_f32_1row;
+    cl_kernel kernel_mul_mat_f16_f32;
+    cl_kernel kernel_mul_mat_f16_f32_l4;
+    cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
+    cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0, kernel_mul_mat_q4_0_f32_flat;
+    cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
+    cl_kernel kernel_convert_block_q4_0_noshuffle, kernel_mul_mat_q4_0_f32_flat_v0,
+              kernel_mul_mat_q4_0_f32_flat_img_v0;
+    cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
+    cl_kernel kernel_mul_mv_q6_K_f32;
+
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+    // Transpose kernels
+    cl_program program_transpose_32;
+    cl_program program_transpose_32_16;
+    cl_program program_transpose_16;
+    cl_kernel kernel_transpose_32;
+    cl_kernel kernel_transpose_32_16;
+    cl_kernel kernel_transpose_16;
+
+    cl_mem A_s_d_max;            // max scale buffer size for transpose
+    cl_mem A_q_d_max;            // max weight buffer size for transpose
+    cl_mem B_d_max;              // max activation buffer size for transpose
+
+    // Gemm and Gemv related programs, kernels, etc
+    cl_program program_CL_gemm;
+    cl_program program_CL_gemv_general;
+    cl_program program_CL_gemv_4096_1_11008;
+    cl_program program_CL_gemv_4096_1_4096;
+    cl_program program_CL_gemv_11008_1_4096;
+    cl_program program_CL_gemv_32000_1_4096;
+    cl_kernel CL_mul_mat_Ab_Bi_8x4;
+    cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general;
+    cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008;
+    cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096;
+    cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096;
+    cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096;
+#endif // GGML_OPENCL_USE_ADRENO_KERNELS
+};
+
+static ggml_backend_device                 g_ggml_backend_opencl_device;
+static ggml_backend_opencl_device_context  g_ggml_ctx_dev_main {
+    /*.platform         =*/ nullptr,
+    /*.platform_name    =*/ "",
+    /*.device           =*/ nullptr,
+    /*.device_name      =*/ "",
+};
+
+static int ggml_backend_opencl_n_devices = 0;
+
+// Profiling
+#ifdef GGML_OPENCL_PROFILING
+struct ProfilingInfo {
+    std::string op_name;
+    std::string kernel_name;
+    // Kernel execution time in nanoseconds.
+    cl_ulong duration_ns;
+    // Global and local work sizes.
+    size_t global_size[3];
+    size_t local_size[3];
+    // Op output size.
+    size_t output_size[4];
+};
+
+std::vector<ProfilingInfo> g_profiling_info;
+#endif
+
+inline std::string read_file(const std::string &path) {
+  std::ifstream ifs(path);
+  if (!ifs) {
+    return "";
+  }
+  std::string text;
+  ifs.seekg(0, std::ios::end);
+  text.resize(ifs.tellg());
+  ifs.seekg(0, std::ios::beg);
+  ifs.read(&text[0], text.size());
+  return text;
+}
+
+static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer, const std::string &compile_opts) {
+    cl_program p;
+    char *program_log;
+    size_t program_size;
+    size_t log_size;
+    int err;
+
+    program_size = strlen(program_buffer);
+
+    p = clCreateProgramWithSource(ctx, 1, (const char**)&program_buffer, &program_size, &err);
+    if(err < 0) {
+        GGML_LOG_ERROR("OpenCL error creating program");
+        exit(1);
+    }
+
+    err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL);
+    if(err < 0) {
+        clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
+        program_log = (char*) malloc(log_size + 1);
+        program_log[log_size] = '\0';
+        clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, log_size + 1, program_log, NULL);
+        GGML_LOG_ERROR("ggml_opencl: kernel compile error:\n\n%s\n", program_log);
+        free(program_log);
+        exit(1);
+    }
+
+    return p;
+}
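+
+// Usage sketch (illustrative names): compile a kernel source string and create kernels
+// from the resulting program, mirroring what ggml_cl2_init does below, e.g.
+//     cl_program prog = build_program_from_source(context, device, src.c_str(), "-cl-std=CL2.0");
+//     cl_kernel  k    = clCreateKernel(prog, "kernel_add", &err);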
+
+static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
+    static bool initialized = false;
+    static ggml_backend_opencl_context *backend_ctx = nullptr;
+
+    if (initialized) {
+        return backend_ctx;
+    }
+
+    ggml_backend_opencl_device_context *dev_ctx = (ggml_backend_opencl_device_context *)dev->context;
+    GGML_ASSERT(dev_ctx);
+    GGML_ASSERT(dev_ctx->platform == nullptr);
+    GGML_ASSERT(dev_ctx->device == nullptr);
+    GGML_ASSERT(backend_ctx == nullptr);
+
+    initialized = true;
+    backend_ctx = new ggml_backend_opencl_context();
+    backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN;
+
+    cl_int err;
+
+#ifdef GGML_OPENCL_PROFILING
+    GGML_LOG_INFO("ggml_opencl: OpenCL profiling enabled\n");
+#endif
+
+    struct cl_device;
+    struct cl_platform {
+        cl_platform_id id;
+        unsigned number;
+        char name[128];
+        char vendor[128];
+        struct cl_device * devices;
+        unsigned n_devices;
+        struct cl_device * default_device;
+    };
+
+    struct cl_device {
+        struct cl_platform * platform;
+        cl_device_id id;
+        unsigned number;
+        cl_device_type type;
+        char name[128];
+    };
+
+    enum { NPLAT = 16, NDEV = 16 };
+
+    struct cl_platform platforms[NPLAT];
+    unsigned n_platforms = 0;
+    struct cl_device devices[NDEV];
+    unsigned n_devices = 0;
+    struct cl_device * default_device = NULL;
+
+    cl_platform_id platform_ids[NPLAT];
+    if (clGetPlatformIDs(NPLAT, platform_ids, &n_platforms) != CL_SUCCESS) {
+        GGML_LOG_ERROR("ggml_opencl: plaform IDs not available.\n");
+        return backend_ctx;
+    }
+
+    for (unsigned i = 0; i < n_platforms; i++) {
+        struct cl_platform * p = &platforms[i];
+        p->number = i;
+        p->id = platform_ids[i];
+        CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_NAME, sizeof(p->name), &p->name, NULL));
+        CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_VENDOR, sizeof(p->vendor), &p->vendor, NULL));
+
+        cl_device_id device_ids[NDEV];
+        cl_int clGetDeviceIDsError = clGetDeviceIDs(p->id, CL_DEVICE_TYPE_ALL, NDEV, device_ids, &p->n_devices);
+        if (clGetDeviceIDsError == CL_DEVICE_NOT_FOUND) {
+            p->n_devices = 0;
+        } else {
+            CL_CHECK(clGetDeviceIDsError);
+        }
+        p->devices = p->n_devices > 0 ? &devices[n_devices] : NULL;
+        p->default_device = NULL;
+
+        for (unsigned j = 0; j < p->n_devices; j++) {
+            struct cl_device * d = &devices[n_devices];
+            d->number = n_devices++;
+            d->id = device_ids[j];
+            d->platform = p;
+            CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_NAME, sizeof(d->name), &d->name, NULL));
+            CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_TYPE, sizeof(d->type), &d->type, NULL));
+
+            if (p->default_device == NULL && d->type == CL_DEVICE_TYPE_GPU) {
+                p->default_device = d;
+            }
+        }
+
+        if (default_device == NULL && p->default_device != NULL) {
+            default_device = p->default_device;
+        }
+    }
+
+    if (n_devices == 0) {
+        GGML_LOG_ERROR("ggml_opencl: could find any OpenCL devices.\n");
+        return backend_ctx;
+    }
+
+    char * user_platform_string = getenv("GGML_OPENCL_PLATFORM");
+    char * user_device_string = getenv("GGML_OPENCL_DEVICE");
+    int user_platform_number = -1;
+    int user_device_number = -1;
+
+    unsigned n;
+    if (user_platform_string != NULL && sscanf(user_platform_string, " %u", &n) == 1 && n < n_platforms) {
+        user_platform_number = (int)n;
+    }
+    if (user_device_string != NULL && sscanf(user_device_string, " %u", &n) == 1 && n < n_devices) {
+        user_device_number = (int)n;
+    }
+    if (user_platform_number != -1 && user_device_number != -1) {
+        cl_platform* platform = &platforms[user_platform_number];
+        if ((unsigned)user_device_number >= platform->n_devices) {
+            GGML_LOG_ERROR("ggml_opencl: invalid device number %d\n", user_device_number);
+            exit(1);
+        }
+        default_device = &platform->devices[user_device_number];
+    } else {
+
+        struct cl_device * selected_devices = devices;
+        unsigned n_selected_devices = n_devices;
+
+        if (user_platform_number == -1 && user_platform_string != NULL && user_platform_string[0] != 0) {
+            for (unsigned i = 0; i < n_platforms; i++) {
+                struct cl_platform * p = &platforms[i];
+                if (strstr(p->name, user_platform_string) != NULL ||
+                    strstr(p->vendor, user_platform_string) != NULL) {
+                    user_platform_number = (int)i;
+                    break;
+                }
+            }
+            if (user_platform_number == -1) {
+                GGML_LOG_ERROR("ggml_opencl: no platform matching '%s' was found.\n", user_platform_string);
+                exit(1);
+            }
+        }
+        if (user_platform_number != -1) {
+            struct cl_platform * p = &platforms[user_platform_number];
+            selected_devices = p->devices;
+            n_selected_devices = p->n_devices;
+            default_device = p->default_device;
+            if (n_selected_devices == 0) {
+                GGML_LOG_ERROR("ggml_opencl: selected platform '%s' does not have any devices.\n", p->name);
+                exit(1);
+            }
+        }
+
+        if (user_device_number == -1 && user_device_string != NULL && user_device_string[0] != 0) {
+            for (unsigned i = 0; i < n_selected_devices; i++) {
+                struct cl_device * d = &selected_devices[i];
+                if (strstr(d->name, user_device_string) != NULL) {
+                    user_device_number = d->number;
+                    break;
+                }
+            }
+            if (user_device_number == -1) {
+                GGML_LOG_ERROR("ggml_opencl: no device matching '%s' was found.\n", user_device_string);
+                exit(1);
+            }
+        }
+        if (user_device_number != -1) {
+            selected_devices = &devices[user_device_number];
+            n_selected_devices = 1;
+            default_device = &selected_devices[0];
+        }
+
+        GGML_ASSERT(n_selected_devices > 0);
+
+        if (default_device == NULL) {
+            default_device = &selected_devices[0];
+        }
+    }
+
+    GGML_LOG_INFO("ggml_opencl: selecting platform: '%s'\n", default_device->platform->name);
+    GGML_LOG_INFO("ggml_opencl: selecting device: '%s'\n", default_device->name);
+    if (default_device->type != CL_DEVICE_TYPE_GPU) {
+        GGML_LOG_WARN("ggml_opencl: warning, not a GPU: '%s'.\n", default_device->name);
+    }
+
+    dev_ctx->platform = default_device->platform->id;
+    dev_ctx->device = default_device->id;
+    backend_ctx->device = default_device->id;
+
+    if (strstr(default_device->name, "Adreno")) {
+        backend_ctx->gpu_family = GPU_FAMILY::ADRENO;
+        backend_ctx->adreno_gen = get_adreno_gpu_gen(default_device->name);
+
+        // Default wave size is 128, A8x uses 64.
+        if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::A8X) {
+            backend_ctx->adreno_wave_size = 64;
+        } else if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::A7X ||
+                   backend_ctx->adreno_gen == ADRENO_GPU_GEN::X1E) {
+            backend_ctx->adreno_wave_size = 128;
+        } else {
+            backend_ctx->adreno_wave_size = 128;
+            GGML_LOG_WARN("ggml_opencl: Unsupported Adreno GPU: %s, "
+                "using wave size %d, "
+                "may not work as expected\n",
+                backend_ctx->device_name.c_str(), backend_ctx->adreno_wave_size);
+        }
+    } else if (strstr(default_device->name, "Intel")) {
+        backend_ctx->gpu_family = GPU_FAMILY::INTEL;
+    } else {
+        GGML_LOG_ERROR("Unsupported GPU: %s\n", default_device->name);
+        backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN;
+        return backend_ctx;
+    }
+
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+    if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) {
+        GGML_LOG_ERROR("ggml_opencl: Adreno-specific kernels should not be enabled for non-Adreno GPUs; "
+            "run on an Adreno GPU or recompile with CMake option `-DGGML_OPENCL_USE_ADRENO_KERNELS=OFF`\n");
+        return backend_ctx;
+    }
+#endif
+
+    // Populate backend device name
+    dev_ctx->platform_name = default_device->platform->name;
+    dev_ctx->device_name = default_device->name;
+    backend_ctx->device_name = default_device->name;
+
+    // A local ref of cl_device_id for convenience
+    cl_device_id device = backend_ctx->device;
+
+    // Check device OpenCL version, OpenCL 2.0 or above is required
+    size_t device_ver_str_size;
+    clGetDeviceInfo(device, CL_DEVICE_VERSION, 0, NULL, &device_ver_str_size);
+    char *device_ver_buffer = (char *)alloca(device_ver_str_size + 1);
+    clGetDeviceInfo(device, CL_DEVICE_VERSION, device_ver_str_size, device_ver_buffer, NULL);
+    device_ver_buffer[device_ver_str_size] = '\0';
+    GGML_LOG_INFO("ggml_opencl: device OpenCL version: %s\n", device_ver_buffer);
+
+    if (strstr(device_ver_buffer, "OpenCL 2") == NULL &&
+        strstr(device_ver_buffer, "OpenCL 3") == NULL) {
+        GGML_LOG_ERROR("ggml_opencl: OpenCL 2.0 or above is required\n");
+        return backend_ctx;
+    }
+
+    // Check driver version
+    size_t driver_version_str_size;
+    clGetDeviceInfo(device, CL_DRIVER_VERSION, 0, NULL, &driver_version_str_size);
+    char *driver_version = (char *)alloca(driver_version_str_size + 1);
+    clGetDeviceInfo(device, CL_DRIVER_VERSION, driver_version_str_size, driver_version, NULL);
+    driver_version[driver_version_str_size] = '\0';
+    GGML_LOG_INFO("ggml_opencl: OpenCL driver: %s\n", driver_version);
+    backend_ctx->driver_version = driver_version;
+
+    int adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version);
+    bool has_vector_subgroup_broadcast =
+        adreno_cl_compiler_version >= 47 || adreno_cl_compiler_version == 17;
+    GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n",
+        has_vector_subgroup_broadcast ? "true" : "false");
+
+    size_t ext_str_size;
+    clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size);
+    char *ext_buffer = (char *)alloca(ext_str_size + 1);
+    clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL);
+    ext_buffer[ext_str_size] = '\0'; // ensure it is null terminated
+    // Check if ext_buffer contains cl_khr_fp16
+    backend_ctx->fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL;
+    GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", backend_ctx->fp16_support ? "true" : "false");
+
+    // fp16 is required
+    if (!backend_ctx->fp16_support) {
+        GGML_LOG_ERROR("ggml_opencl: device does not support FP16\n");
+        return backend_ctx;
+    }
+
+    // If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes
+    // optional in OpenCL 3.0 (cl_khr_subgroup is mandatory in OpenCL 2.x)
+    if (strstr(device_ver_buffer, "OpenCL 3") &&
+        strstr(ext_buffer, "cl_khr_subgroups") == NULL &&
+        strstr(ext_buffer, "cl_intel_subgroups") == NULL) {
+        GGML_LOG_ERROR("ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) "
+            "(note that subgroups is an optional feature in OpenCL 3.0)\n");
+        return backend_ctx;
+    }
+
+    CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &backend_ctx->alignment, NULL));
+    GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", backend_ctx->alignment);
+
+    clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL);
+    GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024);
+
+    // Check SVM.
+    cl_device_svm_capabilities svm_caps;
+    CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_SVM_CAPABILITIES, sizeof(cl_device_svm_capabilities), &svm_caps, 0));
+    GGML_LOG_INFO("ggml_opencl: SVM coarse grain buffer support: %s\n",
+        svm_caps & CL_DEVICE_SVM_COARSE_GRAIN_BUFFER ? "true" : "false");
+    GGML_LOG_INFO("ggml_opencl: SVM fine grain buffer support: %s\n",
+        svm_caps & CL_DEVICE_SVM_FINE_GRAIN_BUFFER ? "true" : "false");
+    GGML_LOG_INFO("ggml_opencl: SVM fine grain system support: %s\n",
+        svm_caps & CL_DEVICE_SVM_FINE_GRAIN_SYSTEM ? "true" : "false");
+    GGML_LOG_INFO("ggml_opencl: SVM atomics support: %s\n",
+        svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false");
+
+    // Print out configurations
+#ifdef GGML_OPENCL_SOA_Q
+    GGML_LOG_INFO("ggml_opencl: flattening quantized weights representation as struct of arrays (GGML_OPENCL_SOA_Q)\n");
+#endif // GGML_OPENCL_SOA_Q
+
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+    GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n");
+#endif // GGML_OPENCL_USE_ADRENO_KERNELS
+
+    cl_context_properties properties[] = {
+        (intptr_t)CL_CONTEXT_PLATFORM, (intptr_t)dev_ctx->platform, 0
+    };
+
+    CL_CHECK((backend_ctx->context = clCreateContext(properties, 1, &device, NULL, NULL, &err), err));
+
+    // A local ref of cl_context for convenience
+    cl_context context = backend_ctx->context;
+
+    //CL_CHECK((queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err),
+    //    (err != CL_INVALID_QUEUE_PROPERTIES && err != CL_INVALID_VALUE ? err :
+    //    (queue = clCreateCommandQueue(context, device, 0, &err), err)
+    //)));
+    cl_command_queue_properties command_queue_props = 0;
+#ifdef GGML_OPENCL_PROFILING
+    command_queue_props |= CL_QUEUE_PROFILING_ENABLE;
+#endif
+    CL_CHECK((backend_ctx->queue = clCreateCommandQueue(context, device, command_queue_props, &err), err));
+
+#ifdef GGML_OPENCL_EMBED_KERNELS
+    const std::string kernel_src {
+        #include "ggml-opencl.cl.h"
+    };
+#else
+    const std::string kernel_src = read_file("ggml-opencl.cl");
+#endif
+
+    std::string compile_opts =
+        "-cl-std=CL2.0 -cl-mad-enable -cl-unsafe-math-optimizations "
+        "-cl-finite-math-only -cl-fast-relaxed-math ";
+    backend_ctx->program = build_program_from_source(context, device, kernel_src.c_str(), compile_opts);
+
+    // Non matmul kernels.
+    CL_CHECK((backend_ctx->kernel_get_rows_f32       = clCreateKernel(backend_ctx->program, "kernel_get_rows_f32", &err), err));
+    CL_CHECK((backend_ctx->kernel_get_rows_f16       = clCreateKernel(backend_ctx->program, "kernel_get_rows_f16", &err), err));
+    CL_CHECK((backend_ctx->kernel_get_rows_q4_0      = clCreateKernel(backend_ctx->program, "kernel_get_rows_q4_0", &err), err));
+    CL_CHECK((backend_ctx->kernel_add                = clCreateKernel(backend_ctx->program, "kernel_add", &err), err));
+    CL_CHECK((backend_ctx->kernel_add_row            = clCreateKernel(backend_ctx->program, "kernel_add_row", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul                = clCreateKernel(backend_ctx->program, "kernel_mul", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_row            = clCreateKernel(backend_ctx->program, "kernel_mul_row", &err), err));
+    CL_CHECK((backend_ctx->kernel_scale              = clCreateKernel(backend_ctx->program, "kernel_scale", &err), err));
+    CL_CHECK((backend_ctx->kernel_silu               = clCreateKernel(backend_ctx->program, "kernel_silu", &err), err));
+    CL_CHECK((backend_ctx->kernel_silu_4             = clCreateKernel(backend_ctx->program, "kernel_silu_4", &err), err));
+    CL_CHECK((backend_ctx->kernel_gelu               = clCreateKernel(backend_ctx->program, "kernel_gelu", &err), err));
+    CL_CHECK((backend_ctx->kernel_gelu_4             = clCreateKernel(backend_ctx->program, "kernel_gelu_4", &err), err));
+    CL_CHECK((backend_ctx->kernel_relu               = clCreateKernel(backend_ctx->program, "kernel_relu", &err), err));
+    CL_CHECK((backend_ctx->kernel_clamp              = clCreateKernel(backend_ctx->program, "kernel_clamp", &err), err));
+    CL_CHECK((backend_ctx->kernel_norm               = clCreateKernel(backend_ctx->program, "kernel_norm", &err), err));
+    CL_CHECK((backend_ctx->kernel_rms_norm           = clCreateKernel(backend_ctx->program, "kernel_rms_norm", &err), err));
+    CL_CHECK((backend_ctx->kernel_diag_mask_inf      = clCreateKernel(backend_ctx->program, "kernel_diag_mask_inf", &err), err));
+    CL_CHECK((backend_ctx->kernel_diag_mask_inf_8    = clCreateKernel(backend_ctx->program, "kernel_diag_mask_inf_8", &err), err));
+    CL_CHECK((backend_ctx->kernel_soft_max           = clCreateKernel(backend_ctx->program, "kernel_soft_max", &err), err));
+    CL_CHECK((backend_ctx->kernel_soft_max_4         = clCreateKernel(backend_ctx->program, "kernel_soft_max_4", &err), err));
+    CL_CHECK((backend_ctx->kernel_rope_norm_f32      = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f32", &err), err));
+    CL_CHECK((backend_ctx->kernel_rope_norm_f16      = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f16", &err), err));
+    CL_CHECK((backend_ctx->kernel_rope_neox_f32      = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f32", &err), err));
+    CL_CHECK((backend_ctx->kernel_rope_neox_f16      = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f16", &err), err));
+    CL_CHECK((backend_ctx->kernel_cpy_f16_f16        = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f16", &err), err));
+    CL_CHECK((backend_ctx->kernel_cpy_f16_f32        = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f32", &err), err));
+    CL_CHECK((backend_ctx->kernel_cpy_f32_f16        = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f16", &err), err));
+    CL_CHECK((backend_ctx->kernel_cpy_f32_f32        = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f32", &err), err));
+
+    // Matmul kernels.
+    CL_CHECK((backend_ctx->kernel_mul_mat_f32_f32        = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f32_f32", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mat_f16_f16        = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f16", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_1row   = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32_1row", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32        = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_l4     = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32_l4", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32       = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_v     = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_v", &err), err));
+
+    CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat  = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_flat", &err), err));
+    CL_CHECK((backend_ctx->kernel_convert_block_q4_0     = clCreateKernel(backend_ctx->program, "kernel_convert_block_q4_0", &err), err));
+    CL_CHECK((backend_ctx->kernel_restore_block_q4_0     = clCreateKernel(backend_ctx->program, "kernel_restore_block_q4_0", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_8x_flat", &err), err));
+
+    // Load additional mulmat kernels.
+#ifdef GGML_OPENCL_EMBED_KERNELS
+    const std::string kernel_src_1 {
+        #include "ggml-opencl_mm.cl.h"
+    };
+#else
+    const std::string kernel_src_1 = read_file("ggml-opencl_mm.cl");
+#endif
+    backend_ctx->program_1 = build_program_from_source(context, device, kernel_src_1.c_str(), compile_opts);
+
+    CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_8x_flat      = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_1d_8x_flat", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_16x_flat     = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_1d_16x_flat", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32                  = clCreateKernel(backend_ctx->program_1, "kernel_mul_mv_q6_K_f32", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat_v0         = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_flat_v0", &err), err));
+    CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat_img_v0     = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_flat_img_v0", &err), err));
+
+    // Load additional data conversion kernels.
+#ifdef GGML_OPENCL_EMBED_KERNELS
+    const std::string kernel_src_2 {
+        #include "ggml-opencl_cvt.cl.h"
+    };
+#else
+    const std::string kernel_src_2 = read_file("ggml-opencl_cvt.cl");
+#endif
+    backend_ctx->program_2 = build_program_from_source(context, device, kernel_src_2.c_str(), compile_opts);
+
+    CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle     = clCreateKernel(backend_ctx->program_2, "kernel_convert_block_q4_0_noshuffle", &err), err));
+
+    // Kernels for Adreno
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+#ifdef GGML_OPENCL_EMBED_KERNELS
+    const std::string transpose_32_src {
+        #include "ggml-opencl_transpose_32.cl.h"
+    };
+#else
+    const std::string transpose_32_src = read_file("ggml-opencl_transpose_32.cl");
+#endif
+    backend_ctx->program_transpose_32 = build_program_from_source(context, device, transpose_32_src.c_str(), compile_opts);
+    CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose_32, "kernel_transpose_32", &err), err));
+
+#ifdef GGML_OPENCL_EMBED_KERNELS
+    const std::string transpose_32_16_src {
+        #include "ggml-opencl_transpose_32_16.cl.h"
+    };
+#else
+    const std::string transpose_32_16_src = read_file("ggml-opencl_transpose_32_16.cl");
+#endif
+    backend_ctx->program_transpose_32_16 = build_program_from_source(context, device, transpose_32_16_src.c_str(), compile_opts);
+    CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose_32_16, "kernel_transpose_32_16", &err), err));
+
+#ifdef GGML_OPENCL_EMBED_KERNELS
+    const std::string transpose_16_src {
+        #include "ggml-opencl_transpose_16.cl.h"
+    };
+#else
+    const std::string transpose_16_src = read_file("ggml-opencl_transpose_16.cl");
+#endif
+    backend_ctx->program_transpose_16 = build_program_from_source(context, device, transpose_16_src.c_str(), compile_opts);
+    CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose_16, "kernel_transpose_16", &err), err));
+
+    // Gemv general
+    std::string CL_gemv_compile_opts =
+        " -cl-std=CL2.0 "
+        " -cl-mad-enable "
+        " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size);
+    if (has_vector_subgroup_broadcast) {
+        CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT ";
+    }
+#ifdef GGML_OPENCL_EMBED_KERNELS
+    const std::string kernel_src_CL_gemv_general {
+        #include "ggml-opencl_gemv_noshuffle_general.cl.h"
+    };
+#else
+    const std::string kernel_src_CL_gemv_general = read_file("ggml-opencl_gemv_noshuffle_general.cl");
+#endif
+
+    backend_ctx->program_CL_gemv_general = build_program_from_source(
+        context, device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts);
+    CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general = clCreateKernel(backend_ctx->program_CL_gemv_general, "kernel_gemv_noshuffle", &err), err));
+
+    // Gemv 2048, 16384
+    CL_gemv_compile_opts =
+        " -cl-std=CL2.0 "
+        " -cl-mad-enable "
+        " -DLINE_STRIDE_A=2048 "
+        " -DBLOCK_STRIDE_A=16384 "
+        " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size);
+    if (has_vector_subgroup_broadcast) {
+        CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT ";
+    }
+#ifdef GGML_OPENCL_EMBED_KERNELS
+    const std::string kernel_src_CL_gemv {
+        #include "ggml-opencl_gemv_noshuffle.cl.h"
+    };
+#else
+    const std::string kernel_src_CL_gemv = read_file("ggml-opencl_gemv_noshuffle.cl");
+#endif
+
+    backend_ctx->program_CL_gemv_4096_1_4096 = build_program_from_source(
+        context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts);
+    CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_4096, "kernel_gemv_noshuffle", &err), err));
+
+    // Gemv 2048, 16384
+    CL_gemv_compile_opts =
+        " -cl-std=CL2.0 "
+        " -cl-mad-enable "
+        " -DLINE_STRIDE_A=2048 "
+        " -DBLOCK_STRIDE_A=16384 "
+        " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size);
+    if (has_vector_subgroup_broadcast) {
+        CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT ";
+    }
+
+    backend_ctx->program_CL_gemv_4096_1_11008 = build_program_from_source(
+        context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts);
+    CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_11008, "kernel_gemv_noshuffle", &err), err));
+
+    // Gemv 5504, 44032
+    CL_gemv_compile_opts =
+        " -cl-std=CL2.0 "
+        " -cl-mad-enable "
+        " -DLINE_STRIDE_A=5504 "
+        " -DBLOCK_STRIDE_A=44032 "
+        " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size);
+    if (has_vector_subgroup_broadcast) {
+        CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT ";
+    }
+
+    backend_ctx->program_CL_gemv_11008_1_4096 = build_program_from_source(
+        context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts);
+    CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_11008_1_4096, "kernel_gemv_noshuffle", &err), err));
+
+    // Gemv 16000, 128000
+    CL_gemv_compile_opts =
+        " -cl-std=CL2.0 "
+        " -cl-mad-enable "
+        " -DLINE_STRIDE_A=16000 "
+        " -DBLOCK_STRIDE_A=128000 "
+        " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size);
+    if (has_vector_subgroup_broadcast) {
+        CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT ";
+    }
+
+    backend_ctx->program_CL_gemv_32000_1_4096 = build_program_from_source(context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts);
+    CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_32000_1_4096, "kernel_gemv_noshuffle", &err), err));
+
+    // Gemm
+#ifdef GGML_OPENCL_EMBED_KERNELS
+    const std::string kernel_src_CL_gemm {
+        #include "ggml-opencl_mul_mat_Ab_Bi_8x4.cl.h"
+    };
+#else
+    const std::string kernel_src_CL_gemm = read_file("ggml-opencl_mul_mat_Ab_Bi_8x4.cl");
+#endif
+    backend_ctx->program_CL_gemm = build_program_from_source(context, device, kernel_src_CL_gemm.c_str(), compile_opts);
+    CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err));
+
+    // Allocate intermediate buffers and images
+    size_t max_A_q_d_bytes = 311164928;
+    size_t max_A_s_d_bytes = 38895616;
+    size_t max_B_d_bytes = 45088768;
+
+    CL_CHECK((backend_ctx->A_q_d_max = clCreateBuffer(context, 0, max_A_q_d_bytes, NULL, &err), err));
+    CL_CHECK((backend_ctx->A_s_d_max = clCreateBuffer(context, 0, max_A_s_d_bytes, NULL, &err), err));
+    CL_CHECK((backend_ctx->B_d_max   = clCreateBuffer(context, 0, max_B_d_bytes,   NULL, &err), err));
+#endif // GGML_OPENCL_USE_ADRENO_KERNELS
+
+    // For now we support a single device
+    ggml_backend_opencl_n_devices = 1;
+
+    return backend_ctx;
+}
+
+static void ggml_cl2_free(void) {
+#ifdef GGML_OPENCL_PROFILING
+    FILE * fperf = fopen("cl_profiling.csv", "w");
+    if (!fperf) {
+        GGML_LOG_ERROR("Failed to open cl_profiling.csv\n");
+        return;
+    }
+
+    float total_kernel_time = 0;
+    fprintf(fperf, "op name, kernel name, duration (ms), global size, local size, output size\n");
+    for (const ProfilingInfo & info : g_profiling_info) {
+        total_kernel_time += info.duration_ns/1.e6f;
+        fprintf(fperf, "%s,%s,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n",
+            info.op_name.c_str(), info.kernel_name.c_str(), info.duration_ns/1.e6f,
+            info.global_size[0], info.global_size[1], info.global_size[2],
+            info.local_size[0], info.local_size[1], info.local_size[2],
+            info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]);
+    }
+    fclose(fperf);
+
+    GGML_LOG_INFO("ggml_opencl: total kernel time: %f\n", total_kernel_time);
+#endif
+}
+
+//------------------------------------------------------------------------------
+// Tensor extra management
+//------------------------------------------------------------------------------
+struct ggml_tensor_extra_cl {
+    // The buffer object that holds the data.
+    cl_mem data_device;
+    // The offset into the buffer object. This is primarily for scratch buffer
+    // and view operation.
+    // NB: this offset no longer includes view offset (view_offs). Whenever this
+    // offset is used, view_offs should be considered.
+    cl_ulong offset;
+    // The actual size of the cl_mem object. This is needed when returning the
+    // block to the pool.
+    size_t actual_size;
+
+    void reset() {
+        data_device = nullptr;
+        offset = 0;
+        actual_size = 0;
+    }
+};
+
+// Additional tensor extra structs for quantized tensors.
+// These tensors are loaded from files and should not be allocated in scratch --
+// they should always be allocated from the pool. Hence, they do not have an
+// `offset`, which would indicate their location in the scratch buffer.
+struct ggml_tensor_extra_cl_q4_0 {
+    // Quantized values.
+    cl_mem q = nullptr;
+    // Quantized values in image1d_buffer_t.
+    cl_mem q_img = nullptr;
+    // Scales.
+    cl_mem d = nullptr;
+    // Scales in image1d_buffer_t.
+    cl_mem d_img = nullptr;
+    // Size of quantized values.
+    size_t size_q = 0;
+    // Size of scales.
+    size_t size_d = 0;
+
+    ~ggml_tensor_extra_cl_q4_0() {
+        reset();
+    }
+
+    void reset() {
+        // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer.
+        // They must be properly released so that the original buffer can be
+        // properly released to avoid memory leak.
+        if (q != nullptr) {
+            CL_CHECK(clReleaseMemObject(q));
+            q = nullptr;
+        }
+        if (d != nullptr) {
+            CL_CHECK(clReleaseMemObject(d));
+            d = nullptr;
+        }
+        // Currently, q_img and d_img are only initialized when SMALL_ALLOC is
+        // enabled. They point to the images in ggml_backend_opencl_buffer_context.
+        // So, there is no need to release them here.
+        // TODO: initialize them for the non-SMALL_ALLOC path, or remove them.
+        q_img = nullptr;
+        d_img = nullptr;
+        size_q = 0;
+        size_d = 0;
+    }
+};
+
+//------------------------------------------------------------------------------
+// Backend API
+//------------------------------------------------------------------------------
+
+//
+// backend
+//
+static const char * ggml_backend_opencl_name(ggml_backend_t backend) {
+    return "OpenCL";
+
+    UNUSED(backend);
+}
+
+static void ggml_backend_opencl_free(ggml_backend_t backend) {
+    ggml_cl2_free();
+
+    GGML_UNUSED(backend);
+}
+
+static void ggml_backend_opencl_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    GGML_UNUSED(backend);
+    GGML_UNUSED(tensor);
+    GGML_UNUSED(data);
+    GGML_UNUSED(offset);
+    GGML_UNUSED(size);
+}
+
+static void ggml_backend_opencl_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_UNUSED(backend);
+    GGML_UNUSED(tensor);
+    GGML_UNUSED(data);
+    GGML_UNUSED(offset);
+    GGML_UNUSED(size);
+}
+
+static bool ggml_backend_opencl_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) {
+    GGML_UNUSED(backend);
+    GGML_UNUSED(src);
+    GGML_UNUSED(dst);
+    return false;
+}
+
+static void ggml_backend_opencl_synchronize(ggml_backend_t backend) {
+    GGML_UNUSED(backend);
+}
+
+static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        ggml_tensor * node = cgraph->nodes[i];
+
+        if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+            continue;
+        }
+
+        bool ok = ggml_cl_compute_forward(backend, node);
+        if (!ok) {
+            GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+        }
+        GGML_ASSERT(ok);
+    }
+
+    return GGML_STATUS_SUCCESS;
+}
+
+static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    GGML_UNUSED(dev);
+
+    switch (op->op) {
+        case GGML_OP_NONE:
+            return true;
+        case GGML_OP_GET_ROWS:
+            switch (op->src[0]->type) {
+                case GGML_TYPE_F32:
+                case GGML_TYPE_F16:
+                    return true;
+                case GGML_TYPE_Q4_0:
+#ifdef GGML_OPENCL_SOA_Q
+                    // We do not support flattened Q4_0 (and possibly other Q's)
+                    return false;
+#else // GGML_OPENCL_SOA_Q
+                    return true;
+#endif // GGML_OPENCL_SOA_Q
+                default:
+                    return false;
+            }
+        case GGML_OP_CPY:
+        case GGML_OP_DUP:
+        case GGML_OP_CONT:
+            switch (op->src[0]->type) {
+                case GGML_TYPE_F32:
+                    switch (op->type) {
+                        case GGML_TYPE_F16:
+                        case GGML_TYPE_F32:
+                            return true;
+                        default:
+                            return false;
+                    }
+                case GGML_TYPE_F16:
+                    switch (op->type) {
+                        case GGML_TYPE_F16:
+                        case GGML_TYPE_F32:
+                            return true;
+                        default:
+                            return false;
+                    }
+                default:
+                    return false;
+            }
+        case GGML_OP_ADD:
+        case GGML_OP_SCALE:
+        case GGML_OP_MUL:
+            return true;
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_RELU:
+                   return ggml_is_contiguous(op->src[0]);
+                default:
+                    return false;
+            }
+        case GGML_OP_CLAMP:
+        case GGML_OP_SOFT_MAX:
+        case GGML_OP_NORM:
+        case GGML_OP_RMS_NORM:
+            return true;
+        case GGML_OP_MUL_MAT:
+            if (op->src[0]->type == GGML_TYPE_F16) {
+                return true;
+            } else if (op->src[0]->type == GGML_TYPE_F32) {
+                return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
+            } else if (op->src[0]->type == GGML_TYPE_Q4_0 ||
+                       op->src[0]->type == GGML_TYPE_Q6_K) {
+                return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
+            }
+            return false;
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            return true;
+        case GGML_OP_DIAG_MASK_INF:
+            return op->ne[3] == 1;
+        case GGML_OP_ROPE:
+            return true;
+        default:
+            return false;
+    }
+}
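+
+// For example, MUL_MAT with Q4_0 or Q6_K weights is accepted only when src1 is
+// F32 and both operands are contiguous, while GET_ROWS on Q4_0 is rejected
+// when GGML_OPENCL_SOA_Q is defined because the flattened layout is not
+// handled by that kernel.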
+
+// Forward declaration - implementation appears later in the file.
+static const char * ggml_backend_opencl_buffer_type_get_name(ggml_backend_buffer_type_t buffer_type);
+
+static ggml_guid_t ggml_backend_opencl_guid() {
+    static ggml_guid guid = { 0xde, 0xe0, 0x70, 0xa2, 0x73, 0x4e, 0x4d, 0xbc, 0xb0, 0xc7, 0x4f, 0xd4, 0x6d, 0x4e, 0x90, 0xfe };
+    return &guid;
+}
+
+static ggml_backend_i ggml_backend_opencl_i = {
+    /* .get_name                = */ ggml_backend_opencl_name,
+    /* .free                    = */ ggml_backend_opencl_free,
+    /* .set_tensor_async        = */ NULL,  /* ggml_backend_opencl_set_tensor_async */
+    /* .get_tensor_async        = */ NULL,  /* ggml_backend_opencl_get_tensor_async */
+    /* .cpy_tensor_async        = */ NULL,  /* ggml_backend_opencl_cpy_tensor_async */
+    /* .synchronize             = */ NULL,  /* ggml_backend_opencl_synchronize */
+    /* .graph_plan_create       = */ NULL,
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_opencl_graph_compute,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+};
+
+ggml_backend_t ggml_backend_opencl_init(void) {
+    ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_opencl_reg(), 0);
+    ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev);
+
+    ggml_backend_t backend = new ggml_backend {
+        /* .guid      = */ ggml_backend_opencl_guid(),
+        /* .interface = */ ggml_backend_opencl_i,
+        /* .device    = */ dev,
+        /* .context   = */ backend_ctx
+    };
+
+    return backend;
+}
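+
+// Illustrative usage sketch (assumes the generic ggml-backend API, not
+// specific to this file):
+//     ggml_backend_t backend = ggml_backend_opencl_init();
+//     ggml_backend_graph_compute(backend, gf);   // gf is a built ggml_cgraph
+//     ggml_backend_free(backend);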
+
+bool ggml_backend_is_opencl(ggml_backend_t backend) {
+    return backend && backend->iface.get_name == ggml_backend_opencl_name;
+}
+
+//
+// buffer
+//
+struct ggml_backend_opencl_buffer_context {
+    // A buffer context can hold multiple cl_mem objects. This is for flattening
+    // quantized weights and should be used with GGML_OPENCL_SMALL_ALLOC where
+    // each tensor is allocated a separate buffer. When flattening is enabled
+    // with small allocation, each tensor is backed by two cl_mem objects (for
+    // quants and scales) packed into a backend_opencl_buffer.
+    ggml_backend_opencl_buffer_context(cl_mem buf)
+        : name("OpenCL") {
+        buffer.push_back(buf);
+    }
+
+    ~ggml_backend_opencl_buffer_context() {
+        for (cl_mem buf : buffer) {
+            CL_CHECK(clReleaseMemObject(buf));
+        }
+        for (cl_mem im : img) {
+            CL_CHECK(clReleaseMemObject(im));
+        }
+
+        // Delete all extras to trigger their destructors
+        for (ggml_tensor_extra_cl * e : temp_tensor_extras) {
+            delete e;
+        }
+        for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) {
+            delete e;
+        }
+        for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0) {
+            delete e;
+        }
+        for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) {
+            delete e;
+        }
+    }
+
+    ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() {
+        ggml_tensor_extra_cl * extra;
+        if (temp_tensor_extras.empty()) {
+            extra = new ggml_tensor_extra_cl();
+        } else {
+            extra = temp_tensor_extras.back();
+            temp_tensor_extras.pop_back();
+        }
+
+        temp_tensor_extras_in_use.push_back(extra);
+
+        extra->reset();
+        return extra;
+    }
+
+    ggml_tensor_extra_cl_q4_0 * ggml_opencl_alloc_temp_tensor_extra_q4_0() {
+        ggml_tensor_extra_cl_q4_0 * extra;
+        if (temp_tensor_extras_q4_0.empty()) {
+            extra = new ggml_tensor_extra_cl_q4_0();
+        } else {
+            extra = temp_tensor_extras_q4_0.back();
+            temp_tensor_extras_q4_0.pop_back();
+        }
+
+        temp_tensor_extras_q4_0_in_use.push_back(extra);
+
+        extra->reset();
+        return extra;
+    }
+
+    void reset() {
+        for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) {
+            temp_tensor_extras.push_back(e);
+        }
+        temp_tensor_extras_in_use.clear();
+
+        for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) {
+            temp_tensor_extras_q4_0.push_back(e);
+        }
+        temp_tensor_extras_q4_0_in_use.clear();
+    }
+
+    // Pools for extras. Available extras are in `temp_tensor_extras`. Extras
+    // being used are in `temp_tensor_extras_in_use`. At the first run, new
+    // extras get created and put in `in_use`. When the buffer is reset via
+    // the `reset` callback, all extras in `in_use` get moved to available extras
+    // for reuse.
+    std::vector<ggml_tensor_extra_cl *> temp_tensor_extras;
+    std::vector<ggml_tensor_extra_cl *> temp_tensor_extras_in_use;
+    std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0;
+    std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0_in_use;
+
+    // The buffer_context is initially created by ggml_backend_buft_alloc_buffer
+    // before any tensor is initialized (at the beginning of alloc_tensor_range).
+    // Hence, there is always a buffer object in this vector. When each tensor is
+    // being initialized, this original buffer object will be released if both
+    // flattening and small allocation are enabled, and additional buffer
+    // objects will be created in init_tensor to represent flattened quantized
+    // weights.
+    std::vector<cl_mem> buffer;
+    // These are image1d_buffer_t objects that wrap around the quants and scales.
+    // For Q4_0 quantization, there should be two of them - one for quants and
+    // one for scales. They should be populated only when flattening and small
+    // allocation are enabled.
+    std::vector<cl_mem> img;
+    std::string name;
+};
+
+static void * const cl_ptr_base = (void *)(uintptr_t) 0x1000;
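+
+// Note: cl_ptr_base is a fake host base pointer; the allocator assigns
+// tensor->data relative to it, so init_tensor below recovers the byte offset
+// as (char *) tensor->data - (char *) cl_ptr_base. The real storage is the
+// cl_mem tracked in the tensor's extra.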
+
+static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
+    delete ctx;
+}
+
+static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) {
+    return cl_ptr_base;
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+    ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
+
+    ggml_cl2_init(buffer->buft->device);
+
+    if (tensor->view_src != nullptr) {
+        GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
+
+        ggml_tensor_extra_cl * view_extra = (ggml_tensor_extra_cl *) tensor->view_src->extra;
+        GGML_ASSERT(view_extra && "view_extra is nullptr?");
+
+        // Reuse extra of the parent tensor. The offset of this view tensor
+        // becomes `extra->offset + view_offs` and needs to be calculated when
+        // it is used. This change is needed because of the change to
+        // ggml-alloc.c in https://github.com/ggerganov/llama.cpp/pull/7640.
+        // `buffer` passed in here will always be `tensor->buffer`. It is OK
+        // to allocate extras from the same buffer context for ordinary
+        // intermediate tensors. But for views into kv cache tensors, doing so
+        // would mess up the extras used by kv cache.
+        // Before #7640, `buffer` is for intermediate tensors, which is always
+        // different from that of kv cache tensors.
+        //
+        // NB: now extra->offset no longer accounts for view_offs.
+        // NB: this should not apply to weight tensors (for end-to-end runs, but
+        //     may apply for test-backend-ops).
+        // FIXME: if any unexpected results are seen, double check the offset -
+        // there could be other places that need fix.
+        tensor->extra = view_extra;
+    } else {
+        {
+            size_t offset = (char *)tensor->data - (char *)cl_ptr_base;
+
+            ggml_tensor_extra_cl * extra = ctx->ggml_opencl_alloc_temp_tensor_extra();
+            extra->offset = offset;
+            extra->data_device = ctx->buffer[0];
+            extra->actual_size = ggml_nbytes(tensor);
+
+            tensor->extra = extra;
+        }
+    }
+}
+
+// The optimized gemm and gemv kernels are used for large matrices without batch.
+// tensor is the quantized weights matrix.
+inline bool use_adreno_kernels(const ggml_tensor *tensor) {
+    return tensor->ne[0] >= 512 && tensor->ne[1] >= 512 &&
+            tensor->ne[2] == 1 && tensor->ne[3] == 1;
+}
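+
+// For example, a 4096 x 4096 quantized weight matrix takes the optimized path,
+// while a 4096 x 1 tensor or any batched tensor (ne[2] > 1 or ne[3] > 1) falls
+// back to the generic kernels.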
+
+static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device);
+
+    cl_context context = backend_ctx->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+#ifdef GGML_OPENCL_SOA_Q
+    // We separate the quantized bits and scale from block_q4_0 by using an
+    // additional kernel, where each thread handles a block. We first read the
+    // original weights into a temporary buffer, then create two separate
+    // buffers for quantized bits and scales, which are then populated by the
+    // conversion kernel.
+    if (tensor->type == GGML_TYPE_Q4_0) {
+        // Tensors should have been preallocated, therefore they should
+        // already have ggml_tensor_extra_cl as extra.
+        ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
+        GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
+
+        // Allocate the new extra and create aliases from the original.
+        ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
+        ggml_tensor_extra_cl_q4_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_0();
+
+        size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
+        size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;
+        GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size");
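+        // Worked example (illustrative): for a 4096 x 4096 Q4_0 tensor,
+        // ggml_nelements = 16777216 and the block size is 32, so
+        //   size_d = 16777216/32 * 2  = 1048576 bytes (one fp16 scale per block)
+        //   size_q = 16777216/32 * 16 = 8388608 bytes (32 4-bit quants per block)
+        // which sums to ggml_nbytes = 524288 blocks * 18 bytes.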
+
+        cl_int err;
+        cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
+            ggml_nbytes(tensor), NULL, &err);
+        CL_CHECK(err);
+        CL_CHECK(clEnqueueWriteBuffer(
+            queue, data_device, CL_TRUE, 0,
+            ggml_nbytes(tensor), data, 0, NULL, NULL));
+
+        // We always honor the specified offset arg, although for weights
+        // the offset arg should be 0 (we do not assert this).
+        //GGML_ASSERT(offset == 0);
+
+        // We create subbuffers from the original tensor buffer for scales and
+        // quants - i.e., scales and quants are aliases into the buffer object
+        // that backs the original tensor. This is a cleaner way to adapt to the
+        // new memory management.
+        // In the old code, we allocate new buffers for scales and quants
+        // respectively, which could still be done but would result in double
+        // allocation; properly deallocating the preallocated buffer that backs
+        // the tensors is tricky and would leak the backend specific information
+        // into the general backend code.
+        // Does this create misaligned subbuffers (alignment is 1024) in certain
+        // cases?
+        cl_buffer_region region;
+
+        // The original tensor memory is divided into scales and quants, i.e.,
+        // we first store scales, then quants.
+        // Create subbuffer for scales.
+        region.origin = extra_orig->offset + tensor->view_offs + offset;
+        region.size = size_d;
+        extra->d = clCreateSubBuffer(
+            extra_orig->data_device, CL_MEM_READ_WRITE,
+            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
+        CL_CHECK(err);
+
+        // Create subbuffer for quants.
+        region.origin = extra_orig->offset + tensor->view_offs + offset + size_d;
+        region.size = size_q;
+        extra->q = clCreateSubBuffer(
+            extra_orig->data_device, CL_MEM_READ_WRITE,
+            CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
+        CL_CHECK(err);
+
+        //cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0;
+    #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+        cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0;
+
+        // The optimized kernels need weights in natural order, so unshuffle.
+        if (use_adreno_kernels(tensor)) {
+            kernel = backend_ctx->kernel_convert_block_q4_0_noshuffle;
+        }
+    #else
+        cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0;
+    #endif // GGML_OPENCL_USE_ADRENO_KERNELS
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d));
+
+        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
+        size_t local_work_size[] = {64, 1, 1};
+
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+        CL_CHECK(clReleaseMemObject(data_device));
+
+        tensor->extra = extra;
+
+        // transpose the weights and scales
+    #ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+        // Only do the transpose for large, non-batched matrices
+        // TODO: use preallocated images instead of sub-buffer then image
+        if (use_adreno_kernels(tensor)) {
+        // <----------------------------------------------------------------------------------> //
+        // start transpose
+        // <----------------------------------------------------------------------------------> //
+        int M = tensor->ne[1];   // ne01
+        int K = tensor->ne[0];   // ne00
+
+        // transpose is out of place, so we need to allocate transposed buffers
+        // <----------------------------------------------------------------------------------> //
+        // use sub_buffer of max buffer size instead
+
+        size_t q_size_bytes = K * M / 8 * sizeof(float);
+        cl_buffer_region region;
+        region.origin = 0;
+        region.size = q_size_bytes;
+        cl_mem qT_d = clCreateSubBuffer(
+            backend_ctx->A_q_d_max,
+            0,
+            CL_BUFFER_CREATE_TYPE_REGION,
+            &region,
+            &err);
+        // cl_mem qT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, q_size_bytes, NULL, &err);
+        CL_CHECK(err);
+
+        // size_t d_size_bytes = M * (K / 32) / 2 * sizeof(float);
+        size_t d_size_bytes = M * (K / 32) * 2;
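+        // Size note (illustrative): the staging buffers mirror extra->q and
+        // extra->d, just transposed: q_size_bytes = K*M/2 bytes of 4-bit quants
+        // (expressed here as K*M/8 floats), and d_size_bytes covers M*(K/32)
+        // fp16 scales at 2 bytes each.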
+        region.origin = 0;
+        region.size = d_size_bytes;
+        cl_mem dT_d = clCreateSubBuffer(
+            backend_ctx->A_s_d_max,
+            0,
+            CL_BUFFER_CREATE_TYPE_REGION,
+            &region,
+            &err);
+        // cl_mem dT_d = clCreateBuffer(context, CL_MEM_READ_WRITE, d_size_bytes, NULL, &err);
+        CL_CHECK(err);
+
+        // <----------------------------------------------------------------------------------> //
+
+
+        // create images from the buffers
+        // <----------------------------------------------------------------------------------> //
+        cl_mem q_d_image1D;
+        cl_mem d_d_image1D;
+        cl_mem qT_d_image1D;
+        cl_mem dT_d_image1D;
+
+        cl_image_format img_fmt_1d = { CL_RGBA, CL_FLOAT };
+        cl_image_desc img_desc_1d;
+
+        memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+        img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+        img_desc_1d.image_width = M * K / 8 / 4;
+        img_desc_1d.buffer = extra->q;
+        q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
+        CL_CHECK(err);
+
+        img_fmt_1d = { CL_RGBA, CL_FLOAT };
+        memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+        img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+        img_desc_1d.image_width = M * K / 8 / 4;
+        img_desc_1d.buffer = qT_d;
+        qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
+        CL_CHECK(err);
+
+        img_fmt_1d = { CL_RGBA, CL_FLOAT };
+        memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+        img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+        img_desc_1d.image_width = M * K / 32 / 4 / 2;
+        img_desc_1d.buffer = extra->d;
+        d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
+        CL_CHECK(err);
+
+        img_fmt_1d = { CL_RGBA, CL_FLOAT };
+        memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+        img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+        img_desc_1d.image_width = M * K / 32 / 4 / 2;
+        img_desc_1d.buffer = dT_d;
+        dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err);
+        CL_CHECK(err);
+        // <----------------------------------------------------------------------------------> //
+
+        // set up and call the transpose kernels
+        // <----------------------------------------------------------------------------------> //
+        // weights
+        int height_q = M / 8;
+        int width_q = K / 8 / 4;
+        kernel = backend_ctx->kernel_transpose_16;
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qT_d_image1D));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),    &height_q));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int),    &width_q));
+
+        size_t local_size_q[3] = {4, 16, 1};
+        size_t global_size_q[3] = {static_cast<size_t>(width_q), static_cast<size_t>(height_q), 1};
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_q, local_size_q, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+
+        // scales
+        int height_s = M / 8;
+        int width_s = K / 32 / 8;
+
+        kernel = backend_ctx->kernel_transpose_16;
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_s));
+
+        size_t local_size_s[3] = {4, 16, 1};
+        size_t global_size_s[3] = {static_cast<size_t>(width_s), static_cast<size_t>(height_s), 1};
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_s, local_size_s, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+        // <----------------------------------------------------------------------------------> //
+
+        // copy transposed buffer contents to original buffers
+        // <----------------------------------------------------------------------------------> //
+        // weights
+        CL_CHECK(clEnqueueCopyBuffer(queue, qT_d, extra->q, 0, 0, q_size_bytes, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+
+        // scales
+        CL_CHECK(clEnqueueCopyBuffer(queue, dT_d, extra->d, 0, 0, d_size_bytes, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+        // <----------------------------------------------------------------------------------> //
+
+        // deallocate transpose buffers
+        // <----------------------------------------------------------------------------------> //
+        CL_CHECK(clReleaseMemObject(qT_d));
+        CL_CHECK(clReleaseMemObject(dT_d));
+
+        // deallocate temporary images
+        CL_CHECK(clReleaseMemObject(q_d_image1D));
+        CL_CHECK(clReleaseMemObject(d_d_image1D));
+        CL_CHECK(clReleaseMemObject(qT_d_image1D));
+        CL_CHECK(clReleaseMemObject(dT_d_image1D));
+        // <----------------------------------------------------------------------------------> //
+        // end transpose
+        // <----------------------------------------------------------------------------------> //
+        }
+    #endif // GGML_OPENCL_USE_ADRENO_KERNELS
+
+        return;
+    }
+#endif // GGML_OPENCL_SOA_Q
+
+    ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;
+    GGML_ASSERT(extra);
+
+    CL_CHECK(clEnqueueWriteBuffer(
+        queue, extra->data_device, CL_TRUE, extra->offset + offset,
+        size, data, 0, NULL, NULL));
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(tensor->extra);
+
+    ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device);
+
+    cl_context context = backend_ctx->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    // Make sure all previously submitted commands are finished.
+    CL_CHECK(clFinish(queue));
+
+#ifdef GGML_OPENCL_SOA_Q
+    // In end-to-end runs, get_tensor is usually used to get back the logits,
+    // where we can simply do clEnqueueReadBuffer since they are f32.
+    // However, in test-backend-ops, the GPU graph is copied to the CPU backend,
+    // which requires reading back quantized weight tensors.
+    // To properly support this, we need to restore block_q4_0 struct arrays
+    // from the flattened buffers.
+    if (tensor->type == GGML_TYPE_Q4_0) {
+        ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *)tensor->extra;
+
+        cl_int err;
+        cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
+            ggml_nbytes(tensor), NULL, &err);
+        CL_CHECK(err);
+
+        cl_kernel kernel = backend_ctx->kernel_restore_block_q4_0;
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device));
+
+        size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
+        size_t local_work_size[] = {1, 1, 1};
+
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
+            global_work_size, local_work_size, 0, NULL, &evt));
+        CL_CHECK(clWaitForEvents(1, &evt));
+        CL_CHECK(clEnqueueReadBuffer(
+            queue, data_device, CL_TRUE, offset,
+            size, data, 0, NULL, NULL));
+        CL_CHECK(clReleaseMemObject(data_device));
+        return;
+    }
+#endif // GGML_OPENCL_SOA_Q
+
+    ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;
+
+    CL_CHECK(clEnqueueReadBuffer(
+        queue, extra->data_device, CL_TRUE, extra->offset + tensor->view_offs + offset,
+        size, data, 0, NULL, NULL));
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_opencl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    ggml_backend_dev_t dev = buffer->buft->device;
+    ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev);
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
+    for (cl_mem buf : ctx->buffer) {
+        CL_CHECK(clEnqueueFillBuffer(queue, buf, &value, sizeof(value), 0, buffer->size, 0, NULL, NULL));
+    }
+    CL_CHECK(clFinish(queue));
+}
+
+static void ggml_backend_opencl_buffer_reset(ggml_backend_buffer_t buffer) {
+    ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
+    ctx->reset();
+}
+
+static ggml_backend_buffer_i ggml_backend_opencl_buffer_interface = {
+    /* .free_buffer     = */ ggml_backend_opencl_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_opencl_buffer_get_base,
+    /* .init_tensor     = */ ggml_backend_opencl_buffer_init_tensor,
+    /* .memset_tensor   = */ NULL,
+    /* .set_tensor      = */ ggml_backend_opencl_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_opencl_buffer_get_tensor,
+    /* .cpy_tensor      = */ NULL,
+    /* .clear           = */ ggml_backend_opencl_buffer_clear,
+    /* .reset           = */ ggml_backend_opencl_buffer_reset,
+};
+
+//
+// buffer type
+//
+
+static const char * ggml_backend_opencl_buffer_type_get_name(ggml_backend_buffer_type_t buffer_type) {
+    return "OpenCL";
+
+    GGML_UNUSED(buffer_type);
+}
+
+static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buffer_type, size_t size) {
+    ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer_type->device);
+
+    // clCreateBuffer returns -61 (CL_INVALID_BUFFER_SIZE) for size 0
+    size = std::max(size, (size_t)1);
+
+    cl_int err;
+    cl_mem mem = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, size, NULL, &err);
+    if (err != CL_SUCCESS) {
+        GGML_LOG_INFO("%s: failed to allocate %.2f MiB\n", __func__, size / 1024.0 / 1024.0);
+        return nullptr;
+    }
+
+    ggml_backend_opencl_buffer_context * ctx = new ggml_backend_opencl_buffer_context(mem);
+
+    return ggml_backend_buffer_init(buffer_type, ggml_backend_opencl_buffer_interface, ctx, size);
+}
+
+static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) {
+    // FIXME: not thread safe, device may not be initialized yet
+    static cl_uint alignment = -1;
+    if (alignment == (cl_uint)-1) {
+        ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device);
+        alignment = backend_ctx->alignment;
+    }
+    return alignment;
+}
+
+static size_t ggml_backend_opencl_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) {
+    static size_t max_size = -1;
+    if (max_size == (size_t)-1) {
+        ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device);
+        max_size = backend_ctx->max_alloc_size;
+    }
+    return max_size;
+}
+
+static bool ggml_backend_opencl_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
+    return ggml_backend_is_opencl(backend);
+
+    UNUSED(buft);
+}
+
+static ggml_backend_buffer_type_i ggml_backend_opencl_buffer_type_interface = {
+    /* .get_name         = */ ggml_backend_opencl_buffer_type_get_name,
+    /* .alloc_buffer     = */ ggml_backend_opencl_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_opencl_buffer_type_get_alignment,
+    /* .get_max_size     = */ ggml_backend_opencl_buffer_type_get_max_size,
+    /* .get_alloc_size   = */ NULL,
+    /* .is_host          = */ NULL,
+};
+
+ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type() {
+    static ggml_backend_buffer_type buffer_type = {
+        /* .iface   = */ ggml_backend_opencl_buffer_type_interface,
+        /* .device  = */ &g_ggml_backend_opencl_device,
+        /* .context = */ nullptr,
+    };
+
+    return &buffer_type;
+}
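+
+// Illustrative usage (via the generic ggml-backend buffer API):
+//     ggml_backend_buffer_t buf =
+//         ggml_backend_buft_alloc_buffer(ggml_backend_opencl_buffer_type(), size);
+// Tensors placed in this buffer get their OpenCL extras set up by
+// ggml_backend_opencl_buffer_init_tensor above.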
+
+//
+// backend device
+//
+
+static const char * ggml_backend_opencl_device_get_name(ggml_backend_dev_t dev) {
+    return "GPUOpenCL";
+
+    GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_t dev) {
+    ggml_backend_opencl_device_context *dev_ctx = (ggml_backend_opencl_device_context *) dev->context;
+    return dev_ctx->device_name.c_str();
+}
+
+static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    *free = 1;
+    *total = 1;
+
+    GGML_UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_opencl_device_get_type(ggml_backend_dev_t dev) {
+    return GGML_BACKEND_DEVICE_TYPE_GPU;
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_opencl_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_opencl_device_get_name(dev);
+    props->description = ggml_backend_opencl_device_get_description(dev);
+    props->type        = ggml_backend_opencl_device_get_type(dev);
+    ggml_backend_opencl_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = ggml_backend_dev_caps {
+        /* .async                 = */ false,
+        /* .host_buffer           = */ false,
+        /* .buffer_from_host_ptr  = */ false,
+        /* .events                = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_opencl_device_init(ggml_backend_dev_t dev, const char * params) {
+    ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(dev);
+
+    ggml_backend_t backend = new ggml_backend {
+        /* .guid      = */ ggml_backend_opencl_guid(),
+        /* .interface = */ ggml_backend_opencl_i,
+        /* .device    = */ dev,
+        /* .context   = */ backend_ctx,
+    };
+
+    return backend;
+
+    GGML_UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_opencl_device_get_buffer_type(ggml_backend_dev_t dev) {
+    return ggml_backend_opencl_buffer_type();
+
+    GGML_UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_opencl_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    GGML_UNUSED(dev);
+    GGML_UNUSED(ptr);
+    GGML_UNUSED(size);
+    GGML_UNUSED(max_tensor_size);
+    return nullptr;
+}
+
+static bool ggml_backend_opencl_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    return ggml_opencl_supports_op(dev, op);
+}
+
+static bool ggml_backend_opencl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_opencl_buffer_type_get_name;
+
+    GGML_UNUSED(dev);
+}
+
+static struct ggml_backend_device_i ggml_backend_opencl_device_i = {
+    /* .get_name             = */ ggml_backend_opencl_device_get_name,
+    /* .get_description      = */ ggml_backend_opencl_device_get_description,
+    /* .get_memory           = */ ggml_backend_opencl_device_get_memory,
+    /* .get_type             = */ ggml_backend_opencl_device_get_type,
+    /* .get_props            = */ ggml_backend_opencl_device_get_props,
+    /* .init_backend         = */ ggml_backend_opencl_device_init,
+    /* .get_buffer_type      = */ ggml_backend_opencl_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_opencl_device_buffer_from_ptr,
+    /* .supports_op          = */ ggml_backend_opencl_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_opencl_device_supports_buft,
+    /* .offload_op           = */ NULL,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// Backend registry
+
+static const char * ggml_backend_opencl_reg_get_name(ggml_backend_reg_t reg) {
+    return "OpenCL";
+
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_opencl_reg_device_count(ggml_backend_reg_t reg) {
+    return ggml_backend_opencl_n_devices;
+
+    GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_opencl_reg_device_get(ggml_backend_reg_t reg, size_t index) {
+    GGML_ASSERT(index == 0);
+
+    return &g_ggml_backend_opencl_device;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(index);
+}
+
+static struct ggml_backend_reg_i ggml_backend_opencl_reg_i = {
+    /* .get_name         = */ ggml_backend_opencl_reg_get_name,
+    /* .device_count     = */ ggml_backend_opencl_reg_device_count,
+    /* .device_get       = */ ggml_backend_opencl_reg_device_get,
+    /* .get_proc_address = */ NULL,
+};
+
+ggml_backend_reg_t ggml_backend_opencl_reg(void) {
+    // TODO: make this thread-safe somehow?
+    static ggml_backend_reg reg;
+    static bool initialized = false;
+
+    if (!initialized) {
+        reg = ggml_backend_reg {
+            /* .api_version = */ GGML_BACKEND_API_VERSION,
+            /* .iface   = */ ggml_backend_opencl_reg_i,
+            /* .context = */ NULL,
+        };
+
+        g_ggml_backend_opencl_device = ggml_backend_device {
+            /* .iface   = */ ggml_backend_opencl_device_i,
+            /* .reg     = */ &reg,
+            /* .context = */ &g_ggml_ctx_dev_main,
+        };
+
+        ggml_cl2_init(&g_ggml_backend_opencl_device);
+
+        initialized = true;
+    }
+
+    return &reg;
+}
+
+GGML_BACKEND_DL_IMPL(ggml_backend_opencl_reg)
+
+//------------------------------------------------------------------------------
+// Debugging utils
+//------------------------------------------------------------------------------
+#if 0
+#define QK4_0 32
+typedef struct {
+    ggml_fp16_t d;          // delta
+    uint8_t qs[QK4_0 / 2];  // nibbles / quants
+} block_q4_0;
+static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2,
+    "wrong q4_0 block size/padding");
+
+#include <math.h>
+#ifdef __cplusplus
+#include "half.hpp"
+#endif
+
+static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tensor) {
+    void * buf = malloc(ggml_nbytes(tensor));
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+#ifdef GGML_OPENCL_SOA_Q
+    void * buf_q;
+    void * buf_d;
+#endif
+
+#ifdef GGML_USE_OPENCL
+    // Make sure everything is done.
+    CL_CHECK(clFinish(queue));
+
+#ifdef GGML_OPENCL_SOA_Q
+    if (tensor->type == GGML_TYPE_Q4_0) {
+        ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *) tensor->extra;
+        GGML_ASSERT(extra);
+
+        size_t size_q = ggml_nelements(tensor)/QK4_0 * QK4_0/2;
+        size_t size_d = ggml_nelements(tensor)/QK4_0 * sizeof(ggml_fp16_t);
+        GGML_ASSERT(size_q + size_d == ggml_nbytes(tensor));
+        buf_q = malloc(size_q);
+        buf_d = malloc(size_d);
+
+        CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL));
+        CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_d, buf_d, 0, NULL, NULL));
+        CL_CHECK(clFinish(queue));
+    } else {
+        // Read out the tensor from GPU memory.
+        ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;
+        GGML_ASSERT(extra);
+
+        CL_CHECK(clEnqueueReadBuffer(queue, extra->data_device, CL_TRUE,
+        extra->offset, ggml_nbytes(tensor), buf, 0, NULL, NULL));
+        CL_CHECK(clFinish(queue));
+    }
+#else
+    // Read out the tensor from GPU memory.
+    ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra;
+    GGML_ASSERT(extra);
+
+    CL_CHECK(clEnqueueReadBuffer(queue, extra->data_device, CL_TRUE,
+        extra->offset, ggml_nbytes(tensor), buf, 0, NULL, NULL));
+    CL_CHECK(clFinish(queue));
+#endif // GGML_OPENCL_SOA_Q
+#endif // GGML_USE_OPENCL
+
+    // Open file and dump.
+    char fname[512];
+    sprintf(fname, "./tensor-dumps/%s.txt", tensor->name);
+    FILE * f = fopen(fname, "w");
+    if (!f) {
+        printf("Failed to open %s\n", fname);
+        return;
+    }
+
+    if (tensor->type == GGML_TYPE_F32) {
+        float * data = (float *) buf;
+        for (int i = 0; i < ggml_nelements(tensor); ++i) {
+            if (isnan(data[i])) {
+                printf("NaN found: %s\n", tensor->name);
+                break;
+            }
+            fprintf(f, "%f\n", data[i]);
+        }
+    } else if (tensor->type == GGML_TYPE_I32) {
+        int * data = (int *) buf;
+        for (int i = 0; i < ggml_nelements(tensor); ++i) {
+            if (isnan(data[i])) {
+                printf("NaN found: %s\n", tensor->name);
+                break;
+            }
+            fprintf(f, "%d\n", data[i]);
+        }
+    } else if (tensor->type == GGML_TYPE_F16) {
+#ifdef __cplusplus
+        half_float::half * data = (half_float::half *) buf;
+        for (int i = 0; i < ggml_nelements(tensor); ++i) {
+            if (std::isnan(data[i])) {
+                printf("NaN found: %s\n", tensor->name);
+                break;
+            }
+            fprintf(f, "%f\n", float(data[i]));
+        }
+#endif
+    } else if (tensor->type == GGML_TYPE_Q4_0) {
+#ifdef GGML_OPENCL_SOA_Q
+        ggml_fp16_t * data_d = (ggml_fp16_t *)buf_d;
+        unsigned char * data_q = (unsigned char *)buf_q;
+
+        for (int i = 0; i < ggml_nelements(tensor)/QK4_0; ++i) {
+            fprintf(f, "%04x, ", data_d[i]);
+            for (int k = 0; k < QK4_0/2; ++k) {
+                fprintf(f, "%02x, ", data_q[k]);
+            }
+            fprintf(f, "\n");
+            data_q += QK4_0/2;
+        }
+        free(buf_d);
+        free(buf_q);
+#else
+        block_q4_0 * data = (block_q4_0 *) buf;
+        for (int i = 0; i < ggml_nelements(tensor)/QK4_0; ++i) {
+            fprintf(f, "%04x, ", data[i].d);
+            for (int k = 0; k < QK4_0/2; ++k) {
+                fprintf(f, "%02x, ", data[i].qs[k]);
+            }
+            fprintf(f, "\n");
+        }
+#endif // GGML_OPENCL_SOA_Q
+    }
+    free(buf);
+    fflush(f);
+    fclose(f);
+}
+#else
+#define dump_tensor(tensor)
+#endif
+
+//------------------------------------------------------------------------------
+// Profiling utility
+//------------------------------------------------------------------------------
+#ifdef GGML_OPENCL_PROFILING
+void populateProfilingInfo(
+        ProfilingInfo& info, cl_event evt, cl_kernel kernel,
+        size_t global_size[3], size_t local_size[3],
+        const ggml_tensor * tensor) {
+    cl_ulong start;
+    cl_ulong end;
+    CL_CHECK(clWaitForEvents(1, &evt));
+    CL_CHECK(clGetEventProfilingInfo(
+        evt, CL_PROFILING_COMMAND_START, sizeof(cl_ulong), &start, NULL));
+    CL_CHECK(clGetEventProfilingInfo(
+        evt, CL_PROFILING_COMMAND_END, sizeof(cl_ulong), &end, NULL));
+
+    char kernel_name[512];
+    CL_CHECK(clGetKernelInfo(kernel, CL_KERNEL_FUNCTION_NAME,
+        sizeof(kernel_name), kernel_name, NULL));
+
+    info.duration_ns = end - start;
+    info.op_name = tensor->name;
+    info.kernel_name = kernel_name;
+    info.local_size[0]  = local_size[0];
+    info.local_size[1]  = local_size[1];
+    info.local_size[2]  = local_size[2];
+    info.global_size[0] = global_size[0];
+    info.global_size[1] = global_size[1];
+    info.global_size[2] = global_size[2];
+    info.output_size[0] = tensor->ne[0];
+    info.output_size[1] = tensor->ne[1];
+    info.output_size[2] = tensor->ne[2];
+    info.output_size[3] = tensor->ne[3];
+}
+#endif
+
+//------------------------------------------------------------------------------
+// Ops
+//------------------------------------------------------------------------------
+
+static bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    const int64_t ne10 = src1->ne[0];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+
+    // TODO: find the optimal values for these
+    return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
+            src1->type == GGML_TYPE_F32 &&
+             dst->type == GGML_TYPE_F32 &&
+            (ne0 >= 32 && ne1 >= 32 && ne10 >= 32);
+}
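+
+// For example, a 4096 x 4096 weight times a 4096 x 32 activation passes this
+// heuristic, while shapes where any of ne0, ne1 or ne10 is below 32 are
+// rejected.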
+
+static void ggml_cl_nop(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    UNUSED(backend);
+    UNUSED(src0);
+    UNUSED(src1);
+    UNUSED(dst);
+}
+
+static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    const int      ne00 = src0 ? src0->ne[0] : 0;
+    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
+    const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
+    const int      ne10 = src1 ? src1->ne[0] : 0;
+    const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
+    const int      ne11 = src1 ? src1->ne[1] : 0;
+    const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
+    const cl_ulong nb1  = dst  ?  dst->nb[1] : 0;
+    const cl_ulong nb2  = dst  ?  dst->nb[2] : 0;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    cl_kernel kernel;
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            kernel = backend_ctx->kernel_get_rows_f32;
+            break;
+        case GGML_TYPE_F16:
+            kernel = backend_ctx->kernel_get_rows_f16;
+            break;
+        case GGML_TYPE_Q4_0:
+            kernel = backend_ctx->kernel_get_rows_q4_0;
+            break;
+        default:
+            GGML_ASSERT(false && "not implemented");
+    }
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne10));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb10));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2));
+
+    size_t global_work_size[] = {(size_t)ne10, (size_t)ne11, 1};
+    size_t local_work_size[] = {1, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    const int  ne00 = src0 ? src0->ne[0] : 0;
+    const int  ne01 = src0 ? src0->ne[1] : 0;
+    const int  ne02 = src0 ? src0->ne[2] : 0;
+    const int  ne03 = src0 ? src0->ne[3] : 0;
+
+    const cl_ulong nb00 = src0 ? src0->nb[0] : 0;
+    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
+    const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
+    const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
+
+    const int  ne10 = src1 ? src1->ne[0] : 0;
+    const int  ne11 = src1 ? src1->ne[1] : 0;
+    const int  ne12 = src1 ? src1->ne[2] : 0;
+    const int  ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+
+    const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
+    const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
+    const cl_ulong nb12 = src1 ? src1->nb[2] : 0;
+    const cl_ulong nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
+
+    const int  ne0  = dst ? dst->ne[0] : 0;
+    const int  ne1  = dst ? dst->ne[1] : 0;
+    const int  ne2  = dst ? dst->ne[2] : 0;
+    const int  ne3  = dst ? dst->ne[3] : 0;
+
+    const cl_ulong nb0  = dst ? dst->nb[0] : 0;
+    const cl_ulong nb1  = dst ? dst->nb[1] : 0;
+    const cl_ulong nb2  = dst ? dst->nb[2] : 0;
+    const cl_ulong nb3  = dst ? dst->nb[3] : 0;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    bool bcast_row = false;
+    cl_kernel kernel;
+
+    if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+        GGML_ASSERT(ggml_is_contiguous(src0));
+
+        // src1 is a row
+        GGML_ASSERT(ne11 == 1);
+
+        bcast_row = true;
+        int ne = ne00 / 4;
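+        // Row-broadcast fast path (illustrative note): src1 is a single row of
+        // length ne10 == ne00 added to every row of src0. The row kernel
+        // presumably operates on float4 vectors, hence the ne00 % 4 == 0 and
+        // ne10 % 4 == 0 checks, the ne = ne00 / 4 argument, and the
+        // ggml_nelements(dst)/4 global work size below.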
+        kernel = backend_ctx->kernel_add_row;
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra1->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne));
+    } else {
+        kernel = backend_ctx->kernel_add;
+
+        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+        CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));
+        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
+        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
+        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
+        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
+        CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne10));
+        CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne11));
+        CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne12));
+        CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne13));
+        CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
+        CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
+        CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
+        CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
+        CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &ne0));
+        CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &ne1));
+        CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &ne2));
+        CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),      &ne3));
+        CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
+        CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
+        CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
+        CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
+    }
+
+    if (bcast_row) {
+        int n = ggml_nelements(dst)/4;
+        size_t global_work_size[] = {(size_t)n, 1, 1};
+        size_t local_work_size[] = {64, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+    } else {
+        unsigned int nth = MIN(64, ne0);
+        size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03};
+        size_t local_work_size[] = {nth, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+    }
+}
+
+static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    const int ne00 = src0 ? src0->ne[0] : 0;
+    const int ne01 = src0 ? src0->ne[1] : 0;
+    const int ne02 = src0 ? src0->ne[2] : 0;
+    const int ne03 = src0 ? src0->ne[3] : 0;
+
+    const cl_ulong nb00 = src0 ? src0->nb[0] : 0;
+    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
+    const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
+    const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
+
+    const int ne10 = src1 ? src1->ne[0] : 0;
+    const int ne11 = src1 ? src1->ne[1] : 0;
+    const int ne12 = src1 ? src1->ne[2] : 0;
+    const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+
+    const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
+    const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
+    const cl_ulong nb12 = src1 ? src1->nb[2] : 0;
+    const cl_ulong nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
+
+    const int ne0  = dst ? dst->ne[0] : 0;
+    const int ne1  = dst ? dst->ne[1] : 0;
+    const int ne2  = dst ? dst->ne[2] : 0;
+    const int ne3  = dst ? dst->ne[3] : 0;
+
+    const cl_ulong nb0  = dst ? dst->nb[0] : 0;
+    const cl_ulong nb1  = dst ? dst->nb[1] : 0;
+    const cl_ulong nb2  = dst ? dst->nb[2] : 0;
+    const cl_ulong nb3  = dst ? dst->nb[3] : 0;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    bool bcast_row = false;
+    cl_kernel kernel;
+
+    if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+        GGML_ASSERT(ggml_is_contiguous(src0));
+
+        // src1 is a row
+        GGML_ASSERT(ne11 == 1);
+
+        bcast_row = true;
+        int ne = ne00 / 4;
+        kernel = backend_ctx->kernel_mul_row;
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extra1->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &ne));
+    } else {
+        kernel = backend_ctx->kernel_mul;
+
+        CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+        CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+        CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+        CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+        CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne03));
+        CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00));
+        CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01));
+        CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02));
+        CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03));
+        CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne10));
+        CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne11));
+        CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne12));
+        CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne13));
+        CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10));
+        CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11));
+        CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12));
+        CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13));
+        CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &ne0));
+        CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &ne1));
+        CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &ne2));
+        CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),      &ne3));
+        CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0));
+        CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1));
+        CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2));
+        CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3));
+    }
+
+    if (bcast_row) {
+        int n = ggml_nelements(dst)/4;
+        size_t global_work_size[] = {(size_t)n, 1, 1};
+        size_t local_work_size[] = {64, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+    } else {
+        unsigned int nth = MIN(64, ne0);
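+        // general case: one workgroup of nth threads per src0 row, with dims 1 and 2 covering ne02 and ne03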
+        size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03};
+        size_t local_work_size[] = {nth, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+    }
+}
+
+static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    cl_kernel kernel;
+
+    int n = ggml_nelements(dst);
+
+    if (n % 4 == 0) {
+        kernel = backend_ctx->kernel_gelu_4;
+        n /= 4;
+    } else {
+        kernel = backend_ctx->kernel_gelu;
+    }
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+
+    size_t global_work_size[] = {(size_t)n, 1, 1};
+    size_t local_work_size[] = {64, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+static void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    cl_kernel kernel;
+
+    int n = ggml_nelements(dst);
+
+    if (n % 4 == 0) {
+        kernel = backend_ctx->kernel_silu_4;
+        n /= 4;
+    } else {
+        kernel = backend_ctx->kernel_silu;
+    }
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+
+    size_t global_work_size[] = {(size_t)n, 1, 1};
+    size_t local_work_size[] = {64, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+static void ggml_cl_relu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    cl_kernel kernel = backend_ctx->kernel_relu;
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+
+    const int64_t n = ggml_nelements(dst);
+
+    size_t global_work_size[] = {(size_t)n, 1, 1};
+    size_t local_work_size[] = {64, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+static void ggml_cl_clamp(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
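+    // the clamp bounds are packed as two consecutive floats in dst->op_params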
+    float min;
+    float max;
+    memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
+    memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
+
+    cl_kernel kernel = backend_ctx->kernel_clamp;
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float),    &min));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float),    &max));
+
+    const int64_t n = ggml_nelements(dst);
+
+    size_t global_work_size[] = {(size_t)n, 1, 1};
+    size_t local_work_size[] = {64, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    const int ne00 = src0 ? src0->ne[0] : 0;
+    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+    const int nth = MIN(64, ne00);
+
+    cl_kernel kernel = backend_ctx->kernel_norm;
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),    &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong),  &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),    &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong),  &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),       &ne00));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong),  &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float),     &eps));
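+    // arg 7 is per-workgroup local (scratch) memory; in OpenCL, a NULL argument with a size allocates local memory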
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(float)*nth, NULL));
+
+    const int64_t nrows = ggml_nrows(src0);
+
+    size_t global_work_size[] = {(size_t)nrows*nth, 1, 1};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_backend_opencl_device_context * dev_ctx =
+        (ggml_backend_opencl_device_context *)backend->device->context;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    const int ne00 = src0 ? src0->ne[0] : 0;
+    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
+
+    GGML_ASSERT(ne00 % 4 == 0);
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+
+    const int nth = MIN(64, ne00);
+
+    const int64_t nrows = ggml_nrows(src0);
+
+    size_t global_work_size[] = {(size_t)nrows*nth, 1, 1};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+    cl_kernel kernel = backend_ctx->kernel_rms_norm;
+
+    // Note, this kernel declares local memory in kernel args and the size
+    // depends on subgroup size.
+    // Retrieve subgroup size.
+    // Note, this requires OpenCL 2.1 and above
+    size_t sgs;
+    CL_CHECK(clGetKernelSubGroupInfo(kernel, dev_ctx->device,
+        CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE,
+        sizeof(local_work_size), local_work_size,
+        sizeof(size_t), &sgs, NULL));
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),    &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong),  &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),    &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong),  &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),       &ne00));
+    CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong),  &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float),     &eps));
+    // This is local memory - the size depends on subgroup size.
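+    // nth/sgs floats = one slot per subgroup in the workgroup.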
+    CL_CHECK(clSetKernelArg(kernel, 7, sizeof(float)*nth/sgs,  NULL));
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
+    const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+#ifdef GGML_OPENCL_SOA_Q
+    ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra;
+#endif
+
+    const int  ne00 = src0 ? src0->ne[0] : 0;
+    const int  ne01 = src0 ? src0->ne[1] : 0;
+    const int  ne02 = src0 ? src0->ne[2] : 0;
+    const int  ne03 = src0 ? src0->ne[3] : 0;
+
+    const cl_ulong nb00 = src0 ? src0->nb[0] : 0;
+    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
+    const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
+    const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
+
+    const int  ne10 = src1 ? src1->ne[0] : 0;
+    const int  ne11 = src1 ? src1->ne[1] : 0;
+    const int  ne12 = src1 ? src1->ne[2] : 0;
+    const int  ne13 = src1 ? src1->ne[3] : 0;
+
+    const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
+    const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
+    const cl_ulong nb12 = src1 ? src1->nb[2] : 0;
+    const cl_ulong nb13 = src1 ? src1->nb[3] : 0;
+
+    const int  ne0 = dst ? dst->ne[0] : 0;
+    const int  ne1 = dst ? dst->ne[1] : 0;
+
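+    // broadcast ratios: how many src1 batches map onto each src0 batch along dims 2 and 3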
+    int r2 = ne12/ne02;
+    int r3 = ne13/ne03;
+
+    GGML_ASSERT(ne00 == ne10);
+
+    int nth0 = 32;
+    int nth1 = 1;
+    int nrows = 1;
+    // The number of values produced by each subgroup
+    int ndst = 4;
+
+    cl_kernel kernel;
+
+#ifdef GGML_OPENCL_USE_ADRENO_KERNELS
+    cl_context context = backend_ctx->context;
+
+    if (ne01 && ne1 && use_adreno_kernels(src0)) {
+
+    // init CL objects
+    // <--------------------------------------------> //
+    cl_int              status;
+    cl_image_format     img_fmt_1d;
+    cl_image_desc       img_desc_1d;
+    cl_buffer_region    region;
+    cl_mem              A_image1d = nullptr;
+    cl_mem              B_image1d = nullptr;
+    cl_mem              B_sub_buffer = nullptr;
+    cl_mem              C_d = nullptr;
+    // for B transpose
+    cl_mem B_d = nullptr;
+    cl_mem B_d_input_image = nullptr;
+    // <--------------------------------------------> //
+
+    // define matrix dimensions
+    // <--------------------------------------------> //
+    int M = ne01;
+    int N = ne1;
+    int K = ne00;
+    int padding;
+    // <--------------------------------------------> //
+
+    // q4_0 x fp32
+    if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) {
+        // TODO: remove duplicate definitions of image description + format -- move to top
+
+        // create an image for A
+        // <--------------------------------------------> //
+        if (N == 1) {
+            img_fmt_1d = { CL_R, CL_UNSIGNED_INT32};
+        } else {
+            img_fmt_1d = { CL_R, CL_FLOAT};
+        }
+        memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+        img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+        img_desc_1d.image_width = M * K / 2 / 4;    // M*K/2 bytes of packed q4_0 data (two 4-bit values per byte), 4 bytes per texel
+        img_desc_1d.buffer = extra0_q4_0->q;
+        A_image1d = clCreateImage(
+            context,
+            CL_MEM_READ_ONLY,
+            &img_fmt_1d,
+            &img_desc_1d,
+            NULL,
+            &status);
+        CL_CHECK(status);
+        // <--------------------------------------------> //
+
+
+        // create a sub_buffer for B
+        // <--------------------------------------------> //
+        region.origin = (extra1->offset);
+        region.size = K * N * sizeof(float);
+        B_sub_buffer = clCreateSubBuffer(
+            extra1->data_device,
+            0,
+            CL_BUFFER_CREATE_TYPE_REGION,
+            &region,
+            &status);
+        CL_CHECK(status);
+        // <--------------------------------------------> //
+
+        // transpose activation for Skyler's gemm
+        if (N != 1) {
+            //how many extra elements beyond multiple of 8
+            int extra_elements = N % 8;
+
+            //how much padding to add
+            padding = 0;
+            if (extra_elements > 0){
+                padding = 8 - extra_elements;
+            }
+
+            // Specify the starting offset (in bytes)
+            region.origin = 0;
+            // Specify the size of the sub-buffer (divide by 2 for FP16)
+            region.size = K * (N + padding) * sizeof(float)/2;
+            B_d = clCreateSubBuffer(
+                backend_ctx->B_d_max,
+                0,
+                CL_BUFFER_CREATE_TYPE_REGION,
+                &region,
+                &status);
+            CL_CHECK(status);
+
+            cl_image_format image_format_B_d_input = { CL_RGBA, CL_FLOAT };
+            cl_image_desc image_desc_B_d_input = {
+                CL_MEM_OBJECT_IMAGE1D_BUFFER,
+                static_cast<size_t>(K * N / 4),
+                0, 0, 0, 0, 0, 0, 0, { B_sub_buffer }
+            };
+            B_d_input_image = clCreateImage(
+                context,
+                0,
+                &image_format_B_d_input,
+                &image_desc_B_d_input,
+                NULL,
+                &status);
+            CL_CHECK(status);
+
+            cl_image_format image_format_B_d_output = { CL_RGBA, CL_HALF_FLOAT }; //(CL_HALF_FLOAT for FP16)
+            cl_image_desc image_desc_B_d_output = {
+                CL_MEM_OBJECT_IMAGE1D_BUFFER,
+                static_cast<size_t>(K * (N + padding)/4),
+                0, 0, 0, 0, 0, 0, 0, { B_d }
+            };
+            B_image1d = clCreateImage(
+                context,
+                0,
+                &image_format_B_d_output,
+                &image_desc_B_d_output,
+                NULL,
+                &status);
+            CL_CHECK(status);
+
+            int height_B = N/4;
+            int width_B = K/4;
+            int padded_height_B = (N + padding)/4;
+
+            kernel = backend_ctx->kernel_transpose_32_16;
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &B_d_input_image));
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &B_image1d));
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int),    &height_B));
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int),    &width_B));
+            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),    &padded_height_B));
+
+            size_t local_size_t[2] = { 1, 16 };
+            //WGS tuning
+            if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) {
+                local_size_t[0]=4;
+                local_size_t[1]=8;
+            } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) {
+                local_size_t[0]=2;
+                local_size_t[1]=8;
+            } else if(ne0 == 4096 && ne1 == 128 && ne10 == 11008) {
+                local_size_t[0]=1;
+                local_size_t[1]=8;
+            } else if(ne0 == 32000 && ne1 == 128 && ne10 == 4096) {
+                local_size_t[0]=2;
+                local_size_t[1]=8;
+            }
+
+            size_t global_size_t[2] = {
+                static_cast<size_t>(width_B),
+                static_cast<size_t>(padded_height_B)
+            };
+
+            #ifdef GGML_OPENCL_PROFILING
+                cl_event evt;
+                CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_size_t, local_size_t, 0, NULL, &evt));
+
+                g_profiling_info.emplace_back();
+                populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_size_t, local_size_t, dst);
+            #else
+                CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 2, NULL, global_size_t, local_size_t, 0, NULL, NULL));
+            #endif
+        } else {
+            // no need to transpose B in other cases
+            // create an image for B from sub_buffer
+            // <--------------------------------------------> //
+            img_fmt_1d = {CL_RGBA, CL_FLOAT};
+
+            memset(&img_desc_1d, 0, sizeof(img_desc_1d));
+            img_desc_1d.image_width = K * N / 4;
+            img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER;
+            img_desc_1d.buffer = B_sub_buffer;
+            B_image1d = clCreateImage(
+                context,
+                CL_MEM_READ_ONLY,
+                &img_fmt_1d,
+                &img_desc_1d,
+                NULL,
+                &status);
+            CL_CHECK(status);
+            // <--------------------------------------------> //
+        }
+
+        // choose gemm or gemv kernel
+        // <--------------------------------------------> //
+        if (N == 1) {
+            kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general;
+            if (M == 4096 && K == 4096) {
+                kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096;
+            } else if (M == 4096 && K == 11008) {
+                kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008;
+            } else if (M == 11008 && K == 4096) {
+                kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096;
+            } else if (M == 32000 && K == 4096) {
+                kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096;
+            }
+        } else {
+            kernel = backend_ctx->CL_mul_mat_Ab_Bi_8x4;
+        }
+        // <--------------------------------------------> //
+
+        // set kernel args
+        // <--------------------------------------------> //
+        cl_uint k_arg = 0;
+
+        if (N == 1) {
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem),   &A_image1d));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem),   &extra0_q4_0->d));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem),   &B_image1d));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_ulong), &extra1->offset));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(cl_ulong), &extrad->offset));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel,  k_arg++, sizeof(int),      &r3));
+        } else {
+            region.origin = extrad->offset; // Specify the starting offset (in bytes)
+            region.size = M * N * sizeof(float); // Specify the size of the sub-buffer
+            C_d = clCreateSubBuffer(extrad->data_device, CL_MEM_WRITE_ONLY, CL_BUFFER_CREATE_TYPE_REGION, &region, &status);
+            CL_CHECK(status);
+
+            int padded_N = ne1 + padding;
+
+            CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); //A_q_d
+            CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); //A_s_d
+            CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &B_image1d)); //B_d
+            CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &C_d)); //C_d
+            CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),    &ne01)); //M
+            CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),    &padded_N)); //N with padding
+            CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),    &ne00)); //K
+            CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int),    &ne1)); //N without padding
+        }
+        // <--------------------------------------------> //
+
+        // choose workgroup size
+        // <--------------------------------------------> //
+        size_t global_work_size[3] = {
+            64, static_cast<size_t>((M+63)/64), static_cast<size_t>((N+31)/32)};
+        size_t local_work_size[3] = {64, 2, 4};
+
+        global_work_size[0] = (size_t)(ceil((float)ne1/8));
+        global_work_size[1] = (size_t)(ne01/4);
+        global_work_size[2] = (size_t)(1);
+
+        local_work_size[0]  = (size_t)(1); //4x32 for FP32
+        local_work_size[1]  = (size_t)(128);
+        local_work_size[2]  = (size_t)(1);
+
+        //WGS tuning
+        if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) {
+            local_work_size[0] = 1;
+            local_work_size[1] = 128;
+        } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) {
+            local_work_size[0] = 2;
+            local_work_size[1] = 64;
+        } else if (ne0 == 4096 && ne1 == 128 && ne10 == 11008) {
+            local_work_size[0] = 2;
+            local_work_size[1] = 64;
+        } else if (ne0 == 32000 && ne1 == 128 && ne10 == 4096) {
+            local_work_size[0] = 2;
+            local_work_size[1] = 64;
+        }
+
+        if (N == 1) {
+            local_work_size[0] = backend_ctx->adreno_wave_size; // localsize
+            local_work_size[1] = 4; // reduce factor
+            local_work_size[2] = 1;
+
+            global_work_size[0] = M / 2;
+            global_work_size[1] = 4; // reduce factor
+            global_work_size[2] = 1;
+        }
+        // <--------------------------------------------> //
+
+        // enqueue kernel with profiling
+        // <--------------------------------------------> //
+    #ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+        // enqueue kernel without profiling
+    #else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+    #endif
+        // <--------------------------------------------> //
+
+        // deallocate sub buffers and images
+        // <--------------------------------------------> //
+        CL_CHECK(clReleaseMemObject(A_image1d));
+        CL_CHECK(clReleaseMemObject(B_sub_buffer));
+        CL_CHECK(clReleaseMemObject(B_image1d));
+
+        if (N != 1) {
+            CL_CHECK(clReleaseMemObject(B_d));
+            CL_CHECK(clReleaseMemObject(B_d_input_image));
+            CL_CHECK(clReleaseMemObject(C_d));
+        }
+        // <--------------------------------------------> //
+
+        return;
+    }
+    } // if (ne01 && ne1)
+#endif // GGML_OPENCL_USE_ADRENO_KERNELS
+
+    if (!ggml_is_transposed(src0) &&
+        !ggml_is_transposed(src1) &&
+        src1t == GGML_TYPE_F32 &&
+        ne00%32 == 0 &&
+        ne11 > 2) {
+#ifdef GGML_OPENCL_SOA_Q
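+        // SOA layout: the quantized values (q) and the scales (d) live in separate buffers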
+        // Set up kernel.
+        switch(src0t) {
+            case GGML_TYPE_Q4_0:
+                // This should have been satisfied.
+                GGML_ASSERT(ne11 == ne1);
+                GGML_ASSERT(ne01 == ne0);
+
+                if (backend_ctx->gpu_family == INTEL) {
+                    nth0 = 16;
+                    nth1 = 1;
+
+                    kernel = backend_ctx->kernel_mul_mat_q4_0_f32_1d_16x_flat;
+                } else if (backend_ctx->gpu_family == ADRENO) {
+                    nth0 = 64;
+                    nth1 = 1;
+
+                    kernel = backend_ctx->kernel_mul_mat_q4_0_f32_1d_8x_flat;
+                } else {
+                    GGML_ASSERT(false && "TODO: Unknown GPU");
+                }
+
+                CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q4_0->q));
+                CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q4_0->d));
+                CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+                CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+                CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+                CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+                CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+                CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+                CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne10));
+                CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+                CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne0));
+                CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne1));
+                CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &r2));
+                CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &r3));
+                break;
+            default:
+                break;
+        }
+
+        // Launch kernel.
+        if (src0t == GGML_TYPE_Q4_0) {
+            size_t global_work_size[] = {(size_t)(ne01 + 7)/8*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
+            size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
+
+            if (backend_ctx->gpu_family == INTEL) {
+                // Set global size for Intel. It uses 16x output values.
+                global_work_size[0] = (size_t)(ne01 + 15)/16*nth0;
+                global_work_size[1] = (size_t)ne11*nth1;
+                global_work_size[2] = (size_t)ne12*ne13;
+            }
+
+#ifdef GGML_OPENCL_PROFILING
+            cl_event evt;
+            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+            g_profiling_info.emplace_back();
+            populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+            CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+            return;
+        }
+#else // GGML_OPENCL_SOA_Q
+        // TODO: add block_q4_0 variant.
+#endif // GGML_OPENCL_SOA_Q
+    }
+
+    // use custom matrix x vector kernel
+    switch (src0t) {
+        case GGML_TYPE_F32:
+            //GGML_ASSERT(ne02 == ne12);
+            GGML_ASSERT(src1t == GGML_TYPE_F32);
+            kernel = backend_ctx->kernel_mul_mat_f32_f32;
+            nrows = 4;
+
+            if (backend_ctx->gpu_family == INTEL) {
+                nth0 = 32;
+                nth1 = 1;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                nth0 = 64;
+                nth1 = 1;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb00));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne11));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12));
+            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13));
+            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &r3));
+            break;
+        case GGML_TYPE_F16:
+            //GGML_ASSERT(ne02 == ne12);
+            if (backend_ctx->gpu_family == INTEL) {
+                nth0 = 32;
+                nth1 = 1;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                nth0 = 64;
+                nth1 = 1;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            if (src1t == GGML_TYPE_F32) {
+                if (ne11 * ne12 < 4) {
+                    kernel = backend_ctx->kernel_mul_mat_f16_f32_1row;
+                } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                    kernel = backend_ctx->kernel_mul_mat_f16_f32_l4;
+                    nrows = ne11;
+                } else {
+                    kernel = backend_ctx->kernel_mul_mat_f16_f32;
+                    nrows = 4;
+                }
+            } else {
+                kernel = backend_ctx->kernel_mul_mat_f16_f16;
+                nrows = 4;
+            }
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb00));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne11));
+            CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10));
+            CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11));
+            CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12));
+            CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13));
+            CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int),      &r3));
+            break;
+        case GGML_TYPE_Q4_0:
+            // This should have been satisfied.
+            GGML_ASSERT(ne11 == ne1);
+            GGML_ASSERT(ne01 == ne0);
+
+#ifdef GGML_OPENCL_SOA_Q
+            if (backend_ctx->gpu_family == INTEL) {
+                nth0 = 16;
+                nth1 = 1;
+
+                kernel = backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat;
+                ndst = 8;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                nth0 = 64;
+                nth1 = 1;
+
+                kernel = backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat;
+                ndst = 8;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0_q4_0->q));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_mem),   &extra0_q4_0->d));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &r3));
+#else // GGML_OPENCL_SOA_Q
+            if (backend_ctx->gpu_family == INTEL) {
+                // Use 1D local size. Each workgroup is a SIMD group. Each SIMD
+                // group produces N_DST (4 for Q4_0 kernel) values in the result.
+                // The number of workgroups on dim 0 (the leading dimension) is
+                // the nearest multiple of 4 that covers ne0 (equals ne01).
+                nth0 = 16;
+                nth1 = 1;
+
+                kernel = backend_ctx->kernel_mul_mat_q4_0_f32;
+                ndst = 4;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                nth0 = 64;
+                nth1 = 1;
+
+                kernel = backend_ctx->kernel_mul_mat_q4_0_f32_v;
+                ndst = 4;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &r3));
+#endif // GGML_OPENCL_SOA_Q
+            break;
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+            kernel = backend_ctx->kernel_mul_mv_q6_K_f32;
+
+            if (backend_ctx->gpu_family == INTEL) {
+                nth0 = 2;
+                nth1 = 16;
+            } else if (backend_ctx->gpu_family == ADRENO) {
+                nth0 = 2;
+                nth1 = 64;
+            } else {
+                GGML_ASSERT(false && "TODO: Unknown GPU");
+            }
+
+            CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+            CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+            CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+            CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+            CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+            CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+            CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+            CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne10));
+            CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne12));
+            CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne0));
+            CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne1));
+            CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &r2));
+            CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &r3));
+            break;
+        default:
+            GGML_ASSERT(false && "not implemented");
+    }
+
+    if (src0t == GGML_TYPE_Q4_0 ||
+        src0t == GGML_TYPE_Q4_1 ||
+        src0t == GGML_TYPE_Q8_0 ||
+        src0t == GGML_TYPE_Q2_K) {
+        // Each SIMD group produces N_DST values in the result. Assuming each
+        // workgroup has N_SIMDGROUP SIMD groups, then each workgroup will
+        // produce N_DST*N_SIMDGROUP values in the result. Hence, the grid size
+        // (number of workgroups) will be a nearest multiple of
+        // N_DST*N_SIMDGROUP to cover the size of the dimension. Below, ndst is
+        // N_DST*N_SIMDGROUP (see the Q4_0 matmul kernels).
+        size_t global_work_size[] = {(size_t)(ne01 + ndst-1)/ndst*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
+        size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+    } else if (src0t == GGML_TYPE_Q4_K) {
+        GGML_ASSERT(false && "not implemented");
+    } else if (src0t == GGML_TYPE_Q3_K) {
+        GGML_ASSERT(false && "not implemented");
+    } else if (src0t == GGML_TYPE_Q5_K) {
+        GGML_ASSERT(false && "not implemented");
+    } else if (src0t == GGML_TYPE_Q6_K) {
+        size_t global_work_size[] = {(size_t)(ne01+1)/2*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
+        size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+    } else {
+        int64_t ny = (ne11 + nrows - 1)/nrows;
+
+        size_t global_work_size[] = {(size_t)ne01*nth0, (size_t)ny*nth1, (size_t)ne12*ne13};
+        size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+    }
+}
+
+static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+    GGML_UNUSED(src1);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    float scale;
+    memcpy(&scale, dst->op_params, sizeof(scale));
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    cl_kernel kernel = backend_ctx->kernel_scale;
+
+    CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float),    &scale));
+
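+    // the scale kernel operates on float4 vectors, hence nelements / 4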
+    int n = ggml_nelements(dst)/4;
+
+    size_t global_work_size[] = {(size_t)n, 1, 1};
+    size_t local_work_size[] = {64, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+
+    // GGML_OP_CPY happens between src0 and src1.
+    // GGML_OP_DUP and GGML_OP_CONT happen between src0 and dst.
+    UNUSED(dst);
+
+    const int ne00 = src0 ? src0->ne[0] : 0;
+    const int ne01 = src0 ? src0->ne[1] : 0;
+    const int ne02 = src0 ? src0->ne[2] : 0;
+    const int ne03 = src0 ? src0->ne[3] : 0;
+
+    const cl_ulong nb00 = src0 ? src0->nb[0] : 0;
+    const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
+    const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
+    const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
+
+    const int ne10 = src1 ? src1->ne[0] : 0;
+    const int ne11 = src1 ? src1->ne[1] : 0;
+    const int ne12 = src1 ? src1->ne[2] : 0;
+    const int ne13 = src1 ? src1->ne[3] : 0;
+
+    const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
+    const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
+    const cl_ulong nb12 = src1 ? src1->nb[2] : 0;
+    const cl_ulong nb13 = src1 ? src1->nb[3] : 0;
+
+    const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
+    const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+
+    cl_kernel kernel;
+
+    switch (src0t) {
+        case GGML_TYPE_F32:
+            switch (src1t) {
+                case GGML_TYPE_F16:
+                    kernel = backend_ctx->kernel_cpy_f32_f16;
+                    break;
+                case GGML_TYPE_F32:
+                    kernel = backend_ctx->kernel_cpy_f32_f32;
+                    break;
+                default:
+                    GGML_ASSERT(false && "not implemented");
+            }
+            break;
+        case GGML_TYPE_F16:
+            switch (src1t) {
+                case GGML_TYPE_F16:
+                    kernel = backend_ctx->kernel_cpy_f16_f16;
+                    break;
+                case GGML_TYPE_F32:
+                    kernel = backend_ctx->kernel_cpy_f16_f32;
+                    break;
+                default:
+                    GGML_ASSERT(false && "not implemented");
+            }
+            break;
+        default:
+            GGML_ASSERT(false && "not implemented");
+    }
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(int),      &ne01));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne02));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne03));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(cl_ulong), &nb00));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int),      &ne10));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &ne11));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int),      &ne12));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int),      &ne13));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13));
+
+    const int nth = MIN(64, ne00);
+
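+    // one workgroup of nth threads per src0 row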
+    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, src1);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+static void ggml_cl_dup(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cl_cpy(backend, src0, dst, nullptr);
+    UNUSED(src1);
+}
+
+static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    UNUSED(src1);
+
+    int n_past = ((int32_t *)(dst->op_params))[0];
+
+    const int  ne00 = src0 ? src0->ne[0] : 0;
+    const int  ne01 = src0 ? src0->ne[1] : 0;
+    const int  ne02 = src0 ? src0->ne[2] : 0;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    cl_kernel kernel;
+
+    if (ne00%8 == 0) {
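+        // vectorized variant: each work item masks 8 consecutive elements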
+        kernel = backend_ctx->kernel_diag_mask_inf_8;
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),      &ne00));
+        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),      &ne01));
+        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &n_past));
+
+        size_t global_work_size[] = {(size_t)ne00*ne01*ne02/8, 1, 1};
+        size_t local_work_size[] = {64, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+    } else {
+        kernel = backend_ctx->kernel_diag_mask_inf;
+
+        CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem),   &extra0->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
+        CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem),   &extrad->data_device));
+        CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
+        CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int),      &ne00));
+        CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int),      &ne01));
+        CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int),      &n_past));
+
+        size_t global_work_size[] = {(size_t)ne00, (size_t)ne01, (size_t)ne02};
+        size_t local_work_size[] = {64, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+        cl_event evt;
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+        g_profiling_info.emplace_back();
+        populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+        CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+    }
+}
+
+static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    // Softmax now fuses the KQ mask and KQ scale, which used to be two separate
+    // ops applied before it. It also fuses ALiBi when `max_bias > 0`. llama does
+    // not use ALiBi, but some other models do.
+    // KQ_mask
+    if (src1) {
+        GGML_ASSERT(src1->extra);
+    }
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    ggml_tensor_extra_cl * extra1 = src1 ? (ggml_tensor_extra_cl *)src1->extra : nullptr;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
+
+    const int  ne00 = src0 ? src0->ne[0] : 0;
+    const int  ne01 = src0 ? src0->ne[1] : 0;
+    const int  ne02 = src0 ? src0->ne[2] : 0;
+    const int  ne03 = src0 ? src0->ne[3] : 0;
+
+    float scale, max_bias;
+    memcpy(&scale,    dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, dst->op_params + 1, sizeof(float));
+
+    const int nrows_x = ggml_nrows(src0);
+    const int nrows_y = src0->ne[1];
+
+    const int n_head      = nrows_x/nrows_y;
+    const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
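+    // ALiBi slopes: head h gets m0^(h+1) when h < n_head_log2, otherwise
+    // m1^(2*(h - n_head_log2) + 1); with max_bias == 0 the kernels keep the
+    // slope at 1.0f (see kernel_soft_max / kernel_soft_max_4 below).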
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    // The local size must equal the wave (subgroup) size: each workgroup is a
+    // single wave working on one row, where a row runs along the leading dimension.
+    int nth = MIN(32, ne00);
+
+    if (backend_ctx->gpu_family == INTEL) {
+        // This is the same as the initial value.
+        nth = MIN(32, ne00);
+    }
+    else if (backend_ctx->gpu_family == ADRENO) {
+        nth = 64;
+    } else {
+        GGML_ASSERT(false && "TODO: Unknown GPU");
+    }
+
+    cl_kernel kernel;
+
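+    // Use the float4-vectorized kernel when the row length is a multiple of 4.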
+    if (ne00%4 == 0) {
+        kernel = backend_ctx->kernel_soft_max_4;
+    } else {
+        kernel = backend_ctx->kernel_soft_max;
+    }
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   extra1 ? &extra1->data_device : &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(int),      &ne01));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne02));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(float),    &scale));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float),    &max_bias));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float),    &m0));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float),    &m1));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int),      &n_head_log2));
+
+    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+    cl_command_queue queue = backend_ctx->queue;
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    ggml_tensor * src2 = dst->src[2];
+    ggml_tensor_extra_cl * extra2 = src2 ? (ggml_tensor_extra_cl *)src2->extra : nullptr;
+
+    cl_ulong offset2 = extra2 ? extra2->offset + src2->view_offs : offset0;
+
+    const int  ne00 = src0 ? src0->ne[0] : 0;
+    const int  ne01 = src0 ? src0->ne[1] : 0;
+    const int  ne02 = src0 ? src0->ne[2] : 0;
+    const int  ne03 = src0 ? src0->ne[3] : 0;
+
+    const int  nb00 = src0 ? src0->nb[0] : 0;
+    const int  nb01 = src0 ? src0->nb[1] : 0;
+    const int  nb02 = src0 ? src0->nb[2] : 0;
+    const int  nb03 = src0 ? src0->nb[3] : 0;
+
+    const int ne10 = src1 ? src1->ne[0] : 0;
+    const int ne11 = src1 ? src1->ne[1] : 0; UNUSED(ne11);
+    const int ne12 = src1 ? src1->ne[2] : 0; UNUSED(ne12);
+    const int ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
+
+    const int  ne0 = dst ? dst->ne[0] : 0;
+    const int  ne1 = dst ? dst->ne[1] : 0;
+    const int  ne2 = dst ? dst->ne[2] : 0;
+    const int  ne3 = dst ? dst->ne[3] : 0;
+
+    const int  nb0 = dst ? dst->nb[0] : 0;
+    const int  nb1 = dst ? dst->nb[1] : 0;
+    const int  nb2 = dst ? dst->nb[2] : 0;
+    const int  nb3 = dst ? dst->nb[3] : 0;
+
+    GGML_ASSERT(ne10 == ne02);
+
+    int nth = MIN(64, ne00);
+
+    const int n_past     = ((int *) dst->op_params)[0];
+    const int n_dims     = ((int *) dst->op_params)[1];
+    const int mode       = ((int *) dst->op_params)[2];
+    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
+    float freq_base;
+    float freq_scale;
+    float ext_factor;
+    float attn_factor;
+    float beta_fast;
+    float beta_slow;
+
+    memcpy(&freq_base,   (int32_t *) dst->op_params + 5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params + 6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params + 7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params + 9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+
+    const bool is_neox = mode & 2;
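+    // Bit 1 of `mode` selects the NeoX variant: it pairs element i with element
+    // i + n_dims/2, while the normal kernels rotate adjacent pairs (i, i + 1).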
+
+    cl_kernel kernel;
+
+    if (!is_neox) {
+        switch (src0->type) {
+            case GGML_TYPE_F32:
+                kernel = backend_ctx->kernel_rope_norm_f32;
+                break;
+            case GGML_TYPE_F16:
+                kernel = backend_ctx->kernel_rope_norm_f16;
+                break;
+            default:
+                GGML_ASSERT(false);
+        };
+    } else {
+        switch (src0->type) {
+            case GGML_TYPE_F32:
+                kernel = backend_ctx->kernel_rope_neox_f32;
+                break;
+            case GGML_TYPE_F16:
+                kernel = backend_ctx->kernel_rope_neox_f16;
+                break;
+            default:
+                GGML_ASSERT(false);
+        };
+    }
+
+    CL_CHECK(clSetKernelArg(kernel,  0, sizeof(cl_mem),   &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  1, sizeof(cl_ulong), &offset0));
+    CL_CHECK(clSetKernelArg(kernel,  2, sizeof(cl_mem),   &extra1->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  3, sizeof(cl_ulong), &offset1));
+    CL_CHECK(clSetKernelArg(kernel,  4, sizeof(cl_mem),   extra2 ? &extra2->data_device : &extra0->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  5, sizeof(cl_ulong), &offset2));
+    CL_CHECK(clSetKernelArg(kernel,  6, sizeof(cl_mem),   &extrad->data_device));
+    CL_CHECK(clSetKernelArg(kernel,  7, sizeof(cl_ulong), &offsetd));
+    CL_CHECK(clSetKernelArg(kernel,  8, sizeof(int),      &ne00));
+    CL_CHECK(clSetKernelArg(kernel,  9, sizeof(int),      &ne01));
+    CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int),      &ne02));
+    CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int),      &ne03));
+    CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb00));
+    CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb01));
+    CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb02));
+    CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb03));
+    CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int),      &ne0));
+    CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int),      &ne1));
+    CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int),      &ne2));
+    CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int),      &ne3));
+    CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb0));
+    CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb1));
+    CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb2));
+    CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &nb3));
+    CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int),      &n_past));
+    CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int),      &n_dims));
+    CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int),      &n_ctx_orig));
+    CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float),    &freq_base));
+    CL_CHECK(clSetKernelArg(kernel, 28, sizeof(float),    &freq_scale));
+    CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float),    &ext_factor));
+    CL_CHECK(clSetKernelArg(kernel, 30, sizeof(float),    &attn_factor));
+    CL_CHECK(clSetKernelArg(kernel, 31, sizeof(float),    &beta_fast));
+    CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float),    &beta_slow));
+
+    size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
+    size_t local_work_size[] = {(size_t)nth, 1, 1};
+
+#ifdef GGML_OPENCL_PROFILING
+    cl_event evt;
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
+
+    g_profiling_info.emplace_back();
+    populateProfilingInfo(g_profiling_info.back(), evt, kernel, global_work_size, local_work_size, dst);
+#else
+    CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL));
+#endif
+}
+
+//------------------------------------------------------------------------------
+// Op offloading
+//------------------------------------------------------------------------------
+
+typedef void (*ggml_cl_func_t)(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor) {
+    ggml_cl_func_t func = nullptr;
+
+    ggml_tensor * src0 = tensor->src[0];
+    ggml_tensor * src1 = tensor->src[1];
+
+    const bool any_on_device = tensor->extra
+        || (src0 != nullptr && src0->extra)
+        || (src1 != nullptr && src1->extra);
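+
+    // For most ops an OpenCL kernel is used only if at least one involved tensor
+    // already has an OpenCL extra (i.e. its data resides in an OpenCL buffer);
+    // GGML_OP_MUL_MAT additionally consults ggml_cl_can_mul_mat.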
+
+    switch (tensor->op) {
+        case GGML_OP_GET_ROWS:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_get_rows;
+            break;
+        case GGML_OP_CPY:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_cpy;
+            break;
+        case GGML_OP_DUP:
+        case GGML_OP_CONT:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_dup;
+            break;
+        case GGML_OP_ADD:
+            if (!any_on_device) {
+                return false;
+            }
+            GGML_ASSERT(ggml_is_contiguous(src0));
+            GGML_ASSERT(ggml_is_contiguous(src1));
+            func = ggml_cl_add;
+            break;
+        case GGML_OP_MUL:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_mul;
+            break;
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(tensor)) {
+                case GGML_UNARY_OP_GELU:
+                    if (!any_on_device) {
+                        return false;
+                    }
+                    func = ggml_cl_gelu;
+                    break;
+                case GGML_UNARY_OP_SILU:
+                    if (!any_on_device) {
+                        return false;
+                    }
+                    func = ggml_cl_silu;
+                    break;
+                case GGML_UNARY_OP_RELU:
+                    if (!any_on_device) {
+                        return false;
+                    }
+                    func = ggml_cl_relu;
+                    break;
+                default:
+                    return false;
+            } break;
+        case GGML_OP_CLAMP:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_clamp;
+            break;
+        case GGML_OP_NORM:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_norm;
+            break;
+        case GGML_OP_RMS_NORM:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_rms_norm;
+            break;
+        case GGML_OP_MUL_MAT:
+            if (!any_on_device && !ggml_cl_can_mul_mat(tensor->src[0], tensor->src[1], tensor)) {
+                return false;
+            }
+            func = ggml_cl_mul_mat;
+            break;
+        case GGML_OP_SCALE:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_scale;
+            break;
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_nop;
+            break;
+        case GGML_OP_DIAG_MASK_INF:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_diag_mask_inf;
+            break;
+        case GGML_OP_SOFT_MAX:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_soft_max;
+            break;
+        case GGML_OP_ROPE:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cl_rope;
+            break;
+        default:
+            return false;
+    }
+
+    func(backend, tensor->src[0], tensor->src[1], tensor);
+    return true;
+}
diff --git a/ggml/src/ggml-opencl/kernels/embed_kernel.py b/ggml/src/ggml-opencl/kernels/embed_kernel.py
new file mode 100644
index 00000000000..b5d1d7242b6
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/embed_kernel.py
@@ -0,0 +1,26 @@
+#
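+# Wraps every line of the given OpenCL source file in a C++ raw string literal
+# (R"(...)"), so the generated output can be embedded into the backend as a
+# kernel string.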
+
+import sys
+import logging
+logger = logging.getLogger("opencl-embed-kernel")
+
+
+def main():
+    logging.basicConfig(level=logging.INFO)
+
+    if len(sys.argv) != 3:
+        logger.info("Usage: python embed_kernel.py <input> <output>")
+        sys.exit(1)
+
+    ifile = open(sys.argv[1], "r")
+    ofile = open(sys.argv[2], "w")
+
+    for i in ifile:
+        ofile.write('R"({})"\n'.format(i))
+
+    ifile.close()
+    ofile.close()
+
+
+if __name__ == "__main__":
+    main()
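+
+# Example invocation (hypothetical file names), typically driven by the build system:
+#   python embed_kernel.py ggml-opencl.cl ggml-opencl.cl.h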
diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl
new file mode 100644
index 00000000000..d1cdf709bab
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl
@@ -0,0 +1,2683 @@
+#ifdef cl_khr_fp16
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#elif defined(cl_amd_fp16)
+#pragma OPENCL EXTENSION cl_amd_fp16 : enable
+#else
+#error "Half precision floating point not supportedby OpenCL implementation on your device."
+#endif
+
+#ifdef cl_khr_subgroups
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#elif defined(cl_intel_subgroups)
+#pragma OPENCL EXTENSION cl_intel_subgroups : enable
+#else
+#error "Subgroup not supported on your device."
+#endif
+
+#ifdef cl_intel_required_subgroup_size
+// Always use a subgroup size of 32 on Intel.
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+// Always use a subgroup size of 64 on Adreno.
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#else
+// TODO: do not know how to choose subgroup size on other GPUs.
+#error "Selecting subgroup size is not supported on your device."
+#endif
+
+#define QK4_0                   32
+#define QR4_0                   2
+#define QK4_1                   32
+#define QR4_1                   2
+#define QK5_0                   32
+#define QR5_0                   2
+#define QK5_1                   32
+#define QR5_1                   2
+#define QK8_0                   32
+#define QR8_0                   1
+#define QK_K                    256
+#define K_QUANTS_PER_ITERATION  2
+
+typedef char int8_t;
+typedef uchar uint8_t;
+typedef short int16_t;
+typedef ushort uint16_t;
+typedef int int32_t;
+typedef uint uint32_t;
+
+//------------------------------------------------------------------------------
+// block_q4_0
+//------------------------------------------------------------------------------
+struct block_q4_0
+{
+    half d;
+    uint8_t qs[QK4_0 / 2];
+};
+
+//------------------------------------------------------------------------------
+// block_q4_1
+//------------------------------------------------------------------------------
+struct block_q4_1
+{
+    half d;
+    half m;
+    uint8_t qs[QK4_1 / 2];
+};
+
+//------------------------------------------------------------------------------
+// block_q5_0
+//------------------------------------------------------------------------------
+struct block_q5_0
+{
+    half d;
+    uint32_t qh;
+    uint8_t qs[QK5_0 / 2];
+};
+
+//------------------------------------------------------------------------------
+// block_q5_1
+//------------------------------------------------------------------------------
+struct block_q5_1
+{
+    half d;
+    half m;
+    uint32_t qh;
+    uint8_t qs[QK5_1 / 2];
+};
+
+//------------------------------------------------------------------------------
+// block_q8_0
+//------------------------------------------------------------------------------
+struct block_q8_0
+{
+    half d;
+    int8_t qs[QK8_0];
+};
+
+//------------------------------------------------------------------------------
+// block_q2_K
+//------------------------------------------------------------------------------
+struct block_q2_K
+{
+    uint8_t scales[16];
+    uint8_t qs[64];
+    half d;
+    half dmin;
+};
+
+//------------------------------------------------------------------------------
+// block_q3_K
+//------------------------------------------------------------------------------
+struct block_q3_K
+{
+    uint8_t hmask[32];
+    uint8_t qs[64];
+    uint8_t scales[12];
+    half d;
+};
+
+//------------------------------------------------------------------------------
+// block_q4_K
+//------------------------------------------------------------------------------
+struct block_q4_K
+{
+    half d;
+    half dmin;
+    uint8_t scales[12];
+    uint8_t qs[128];
+};
+
+//------------------------------------------------------------------------------
+// block_q5_K
+//------------------------------------------------------------------------------
+struct block_q5_K
+{
+    half d;
+    half dmin;
+    uint8_t scales[12];
+    uint8_t qh[32];
+    uint8_t qs[128];
+};
+
+//------------------------------------------------------------------------------
+// block_q6_K
+//------------------------------------------------------------------------------
+struct block_q6_K
+{
+    uint8_t ql[128];
+    uint8_t qh[64];
+    int8_t scales[16];
+    half d;
+};
+
+//------------------------------------------------------------------------------
+// dequantize_q4_0_f32, dequantize_q4_0_f16
+//------------------------------------------------------------------------------
+void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) {
+    global ushort * qs = ((global ushort *)xb + 1);
+    float d1 = il ? (xb->d / 16.h) : xb->d;
+    float d2 = d1 / 256.f;
+    float md = -8.h * xb->d;
+    ushort mask0 = il ? 0x00F0 : 0x000F;
+    ushort mask1 = mask0 << 8;
+
+    reg->s0 = d1 * (qs[0] & mask0) + md;
+    reg->s1 = d2 * (qs[0] & mask1) + md;
+
+    reg->s2 = d1 * (qs[1] & mask0) + md;
+    reg->s3 = d2 * (qs[1] & mask1) + md;
+
+    reg->s4 = d1 * (qs[2] & mask0) + md;
+    reg->s5 = d2 * (qs[2] & mask1) + md;
+
+    reg->s6 = d1 * (qs[3] & mask0) + md;
+    reg->s7 = d2 * (qs[3] & mask1) + md;
+
+    reg->s8 = d1 * (qs[4] & mask0) + md;
+    reg->s9 = d2 * (qs[4] & mask1) + md;
+
+    reg->sa = d1 * (qs[5] & mask0) + md;
+    reg->sb = d2 * (qs[5] & mask1) + md;
+
+    reg->sc = d1 * (qs[6] & mask0) + md;
+    reg->sd = d2 * (qs[6] & mask1) + md;
+
+    reg->se = d1 * (qs[7] & mask0) + md;
+    reg->sf = d2 * (qs[7] & mask1) + md;
+}
+
+void dequantize_q4_0_f16(global struct block_q4_0 * xb, short il, half16 * reg) {
+    global ushort * qs = ((global ushort *)xb + 1);
+    half d1 = il ? (xb->d / 16.h) : xb->d;
+    half d2 = d1 / 256.h;
+    half md = -8.h * xb->d;
+    ushort mask0 = il ? 0x00F0 : 0x000F;
+    ushort mask1 = mask0 << 8;
+
+    reg->s0 = d1 * (qs[0] & mask0) + md;
+    reg->s1 = d2 * (qs[0] & mask1) + md;
+
+    reg->s2 = d1 * (qs[1] & mask0) + md;
+    reg->s3 = d2 * (qs[1] & mask1) + md;
+
+    reg->s4 = d1 * (qs[2] & mask0) + md;
+    reg->s5 = d2 * (qs[2] & mask1) + md;
+
+    reg->s6 = d1 * (qs[3] & mask0) + md;
+    reg->s7 = d2 * (qs[3] & mask1) + md;
+
+    reg->s8 = d1 * (qs[4] & mask0) + md;
+    reg->s9 = d2 * (qs[4] & mask1) + md;
+
+    reg->sa = d1 * (qs[5] & mask0) + md;
+    reg->sb = d2 * (qs[5] & mask1) + md;
+
+    reg->sc = d1 * (qs[6] & mask0) + md;
+    reg->sd = d2 * (qs[6] & mask1) + md;
+
+    reg->se = d1 * (qs[7] & mask0) + md;
+    reg->sf = d2 * (qs[7] & mask1) + md;
+}
+
+//------------------------------------------------------------------------------
+// add
+//------------------------------------------------------------------------------
+
+// general-purpose kernel for addition of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
+// cons: not very efficient
+kernel void kernel_add(
+        global char * src0,
+        ulong  offset0,
+        global char * src1,
+        ulong  offset1,
+        global char * dst,
+        ulong  offsetd,
+        int   ne00,
+        int   ne01,
+        int   ne02,
+        int   ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int   ne10,
+        int   ne11,
+        int   ne12,
+        int   ne13,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int   ne0,
+        int   ne1,
+        int   ne2,
+        int   ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst = dst + offsetd;
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int i13 = i03 % ne13;
+    int i12 = i02 % ne12;
+    int i11 = i01 % ne11;
+
+    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const int i10 = i0 % ne10;
+        *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) + *((global float *)(src1_ptr + i10*nb10));
+    }
+}
+
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_add_row(
+        global float4 * src0,
+        ulong  offset0,
+        global float4 * src1,
+        ulong  offset1,
+        global float4 * dst,
+        ulong  offsetd,
+        int ne
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    src1 = (global float4*)((global char*)src1 + offset1);
+    dst = (global float4*)((global char*)dst + offsetd);
+
+    // This performs better than using %.
+    uint gid = get_global_id(0);
+    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
+    dst[gid] = src0[gid] + src1[idx1];
+}
+
+//------------------------------------------------------------------------------
+// mul
+//------------------------------------------------------------------------------
+kernel void kernel_mul(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global char * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        int ne13,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst  = dst + offsetd;
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int i13 = i03 % ne13;
+    int i12 = i02 % ne12;
+    int i11 = i01 % ne11;
+
+    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+
+    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+        const int i10 = i0 % ne10;
+        *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) * *((global float *)(src1_ptr + i10*nb10));
+    }
+}
+
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_mul_row(
+        global float4 * src0,
+        ulong offset0,
+        global float4 * src1,
+        ulong offset1,
+        global float4 * dst,
+        ulong offsetd,
+        int ne
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    src1 = (global float4*)((global char*)src1 + offset1);
+    dst = (global float4*)((global char*)dst + offsetd);
+
+    // This performs better than using %.
+    uint gid = get_global_id(0);
+    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
+    dst[gid] = src0[gid] * src1[idx1];
+}
+
+//------------------------------------------------------------------------------
+// scale
+//------------------------------------------------------------------------------
+kernel void kernel_scale(
+        global float4 * src0,
+        ulong offset0,
+        global float4 * dst,
+        ulong offsetd,
+        float scale
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst = (global float4*)((global char*)dst + offsetd);
+    dst[get_global_id(0)] = src0[get_global_id(0)] * scale;
+}
+
+//------------------------------------------------------------------------------
+// gelu
+//------------------------------------------------------------------------------
+#define GELU_COEF_A     0.044715f
+#define SQRT_2_OVER_PI  0.79788456080286535587989211986876f
+
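+// Both kernels below evaluate the tanh approximation of GELU:
+//   0.5*x*(1 + tanh(sqrt(2/pi) * x * (1 + 0.044715*x^2)))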
+kernel void kernel_gelu(
+    global float * src0,
+    ulong offset0,
+    global float * dst,
+    ulong offsetd
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    float x = src0[get_global_id(0)];
+
+    dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_4(
+    global float4 * src0,
+    ulong offset0,
+    global float4 * dst,
+    ulong offsetd
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst = (global float4*)((global char*)dst + offsetd);
+
+    float4 x = src0[get_global_id(0)];
+
+    dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+//------------------------------------------------------------------------------
+// silu
+//------------------------------------------------------------------------------
+kernel void kernel_silu(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    float x = src0[get_global_id(0)];
+    dst[get_global_id(0)] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_silu_4(
+        global float4 * src0,
+        ulong offset0,
+        global float4 * dst,
+        ulong offsetd
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst = (global float4*)((global char*)dst + offsetd);
+
+    float4 x = src0[get_global_id(0)];
+    dst[get_global_id(0)] = x / (1.0f + exp(-x));
+}
+
+//------------------------------------------------------------------------------
+// relu
+//------------------------------------------------------------------------------
+kernel void kernel_relu(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = fmax(0.0f, src0[get_global_id(0)]);
+}
+
+//------------------------------------------------------------------------------
+// clamp
+//------------------------------------------------------------------------------
+kernel void kernel_clamp(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        float min,
+        float max
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    dst[get_global_id(0)] = src0[get_global_id(0)] < min ?
+        min :
+        (src0[get_global_id(0)] > max ? max : src0[get_global_id(0)]);
+}
+
+//------------------------------------------------------------------------------
+// norm
+//------------------------------------------------------------------------------
+kernel void kernel_norm(
+        global void * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        ulong nb01,
+        float eps,
+        local float * sum
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    dst = (global void*)((global char*)dst + offsetd);
+
+    global float * x = (global float *) ((global char *) src0 + get_group_id(0)*nb01);
+
+    // MEAN
+    // parallel sum
+    sum[get_local_id(0)] = 0.0f;
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        sum[get_local_id(0)] += x[i00];
+    }
+    // reduce
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (uint i = get_local_size(0)/2; i > 0; i /= 2) {
+        if (get_local_id(0) < i) {
+            sum[get_local_id(0)] += sum[get_local_id(0) + i];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    float mean  = sum[0] / ne00;
+
+    // recenter and VARIANCE
+    barrier(CLK_LOCAL_MEM_FENCE);
+    global float * y = dst + get_group_id(0)*ne00;
+    sum[get_local_id(0)] = 0.0f;
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        y[i00] = x[i00] - mean;
+        sum[get_local_id(0)] += y[i00] * y[i00];
+    }
+
+    // reduce
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (uint i = get_local_size(0)/2; i > 0; i /= 2) {
+        if (get_local_id(0) < i) {
+            sum[get_local_id(0)] += sum[get_local_id(0) + i];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    float variance = sum[0] / ne00;
+
+    float scale = 1.0f/sqrt(variance + eps);
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        y[i00] = y[i00] * scale;
+    }
+}
+
+//------------------------------------------------------------------------------
+// rms_norm
+//------------------------------------------------------------------------------
+// This kernel depends on subgroup size.
+kernel void kernel_rms_norm(
+        global void * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        ulong nb01,
+        float eps,
+        local float * sum // Note: the size depends on the number of subgroups
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    global float4 * x = (global float4 *) ((global char *) src0 + get_group_id(0)*nb01);
+    global float * x_scalar = (global float *) x;
+    float4 sumf = 0;
+    float all_sum = 0;
+
+    // parallel sum
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        sumf += x[i00] * x[i00];
+    }
+    all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3;
+    all_sum = sub_group_reduce_add(all_sum);
+    if (get_sub_group_local_id() == 0) {
+        sum[get_sub_group_id()] = all_sum;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+    // broadcast
+    for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
+       if (get_local_id(0) < i) {
+           sum[get_local_id(0)] += sum[get_local_id(0) + i];
+       }
+    }
+    if (get_local_id(0) == 0) {
+        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
+            sum[0] += x_scalar[i];
+        }
+        sum[0] /= ne00;
+    }
+
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    const float mean  = sum[0];
+    const float scale = 1.0f/sqrt(mean + eps);
+
+    global float4 * y = (global float4 *) (dst + get_group_id(0)*ne00);
+    global float * y_scalar = (global float *) y;
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        y[i00] = x[i00] * scale;
+    }
+    if (get_local_id(0) == 0) {
+        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
+            y_scalar[i00] = x_scalar[i00] * scale;
+        }
+    }
+}
+
+//------------------------------------------------------------------------------
+// diag_mask_inf kernels
+//------------------------------------------------------------------------------
+kernel void kernel_diag_mask_inf(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int n_past
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i02 = get_global_id(2);
+    int i01 = get_global_id(1);
+    int i00 = get_global_id(0);
+
+    if (i00 > n_past + i01) {
+        dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
+    } else {
+        dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+    }
+}
+
+kernel void kernel_diag_mask_inf_8(
+        global float4 * src0,
+        ulong offset0,
+        global float4 * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int n_past
+) {
+    src0 = (global float4*)((global char*)src0 + offset0);
+    dst = (global float4*)((global char*)dst + offsetd);
+
+    int i = 2*get_global_id(0);
+
+    dst[i+0] = src0[i+0];
+    dst[i+1] = src0[i+1];
+    int i4 = 4*i;
+    int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+    int i01 = i4/(ne00);      i4 -= i01*ne00;
+    int i00 = i4;
+    for (int k = 3; k >= 0; --k) {
+        if (i00 + 4 + k <= n_past + i01) {
+            break;
+        }
+        (&dst[i+1])[k] = -INFINITY;
+        if (i00 + k > n_past + i01) {
+            (&dst[i])[k] = -INFINITY;
+        }
+    }
+}
+
+//------------------------------------------------------------------------------
+// softmax
+//------------------------------------------------------------------------------
+kernel void kernel_soft_max(
+        global float * src0,
+        ulong offset0,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        float scale,
+        float max_bias,
+        float m0,
+        float m1,
+        int n_head_log2
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0;
+    global float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    float slope = 1.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        int h = i02;
+
+        float base = h < n_head_log2 ? m0 : m1;
+        int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
+    // parallel max
+    float lmax = -INFINITY;
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
+    }
+    float max = sub_group_reduce_max(lmax);
+
+    // parallel sum
+    float lsum = 0.0f;
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max);
+        lsum += exp_psrc0;
+        // Remember the result of exp here. exp is expensive, so we really do not
+        // wish to compute it twice.
+        pdst[i00] = exp_psrc0;
+    }
+
+    const float sum = sub_group_reduce_add(lsum);
+
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        pdst[i00] /= sum;
+    }
+}
+
+#ifdef ADRENO_GPU
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_soft_max_4(
+        global float * src0,
+        ulong offset0,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        float scale,
+        float max_bias,
+        float m0,
+        float m1,
+        int n_head_log2
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0;
+    global float4 * pdst4 = (global float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+
+    float slope = 1.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        int h = i02;
+
+        float base = h < n_head_log2 ? m0 : m1;
+        int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
+    // parallel max
+    float4 lmax4 = -INFINITY;
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
+    }
+    float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3));
+
+    const float max = sub_group_reduce_max(lmax);
+
+    // parallel sum
+    float4 lsum4 = 0.0f;
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max);
+        lsum4 += exp_psrc4;
+        pdst4[i00] = exp_psrc4;
+    }
+    float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3;
+
+    const float sum = sub_group_reduce_add(lsum);
+
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        pdst4[i00] /= sum;
+    }
+}
+
+//------------------------------------------------------------------------------
+// kernel_rope
+//------------------------------------------------------------------------------
+float rope_yarn_ramp(float low, float high, int i0) {
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
+    return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+float2 rope_yarn(
+    float theta_extrap, float freq_scale, float2 corr_dims, int i0, float ext_factor, float mscale
+) {
+    // Get n-d rotational scaling corrected for extrapolation
+    float theta_interp = freq_scale * theta_extrap;
+    float theta = theta_interp;
+    if (ext_factor != 0.0f) {
+        float ramp_mix = rope_yarn_ramp(corr_dims.s0, corr_dims.s1, i0) * ext_factor;
+        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+        // Get n-d magnitude scaling corrected for interpolation
+        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
+    }
+    return (float2)(cos(theta) * mscale, sin(theta) * mscale);
+}
+
+// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
+// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
+float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
+    return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
+}
+
+float2 rope_yarn_corr_dims(
+    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow
+) {
+    // start and end correction dims
+    return (float2)(
+        max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))),
+        min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)))
+    );
+}
+
+kernel void kernel_rope_norm_f32(
+        global void * src0,
+        ulong offset0,
+        global int * src1,
+        ulong offset1,
+        global float * src2,
+        ulong offset2,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
+        int n_past,
+        int n_dims,
+        int n_ctx_orig,
+        float freq_base,
+        float freq_scale,
+        float ext_factor,
+        float attn_factor,
+        float beta_fast,
+        float beta_slow
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global int*)((global char*)src1 + offset1);
+    src2 = (global float*)((global char*)src2 + offset2);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i3 = get_group_id(2);
+    int i2 = get_group_id(1);
+    int i1 = get_group_id(0);
+
+    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
+
+    global int * pos = src1;
+
+    float theta_base = (float) pos[i2];
+    float inv_ndims = -1.f/n_dims;
+
+    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
+        if (i0 < n_dims) {
+            int ic = i0/2;
+
+            float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
+
+            global float * src       = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            global float * dst_data  = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            float x0 = src[0];
+            float x1 = src[1];
+
+            dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
+            dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
+        } else {
+            global float * src      = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            global float * dst_data = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
+kernel void kernel_rope_norm_f16(
+        global void * src0,
+        ulong offset0,
+        global int * src1,
+        ulong offset1,
+        global float * src2,
+        ulong offset2,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
+        int n_past,
+        int n_dims,
+        int n_ctx_orig,
+        float freq_base,
+        float freq_scale,
+        float ext_factor,
+        float attn_factor,
+        float beta_fast,
+        float beta_slow
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global int*)((global char*)src1 + offset1);
+    src2 = (global float*)((global char*)src2 + offset2);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i3 = get_group_id(2);
+    int i2 = get_group_id(1);
+    int i1 = get_group_id(0);
+
+    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
+
+    global int * pos = src1;
+
+    float theta_base = (float) pos[i2];
+    float inv_ndims = -1.f/n_dims;
+
+    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
+        if (i0 < n_dims) {
+            int ic = i0/2;
+
+            float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
+
+            global half * src       = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            global half * dst_data  = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            float x0 = src[0];
+            float x1 = src[1];
+
+            dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
+            dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
+        } else {
+            global half * src      = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            global half * dst_data = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
+kernel void kernel_rope_neox_f32(
+        global void * src0,
+        ulong offset0,
+        global int * src1,
+        ulong offset1,
+        global float * src2,
+        ulong offset2,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
+        int n_past,
+        int n_dims,
+        int n_ctx_orig,
+        float freq_base,
+        float freq_scale,
+        float ext_factor,
+        float attn_factor,
+        float beta_fast,
+        float beta_slow
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global int*)((global char*)src1 + offset1);
+    src2 = (global float*)((global char*)src2 + offset2);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i3 = get_group_id(2);
+    int i2 = get_group_id(1);
+    int i1 = get_group_id(0);
+
+    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
+
+    global int * pos = src1;
+
+    float theta_base = (float) pos[i2];
+    float inv_ndims = -1.f/n_dims;
+
+    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
+        if (i0 < n_dims) {
+            int ic = i0/2;
+
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
+
+            global float * src      = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+            global float * dst_data = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[n_dims/2];
+
+            dst_data[0]        = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
+            dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
+        } else {
+            global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            global float * dst_data  = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
+kernel void kernel_rope_neox_f16(
+        global void * src0,
+        ulong offset0,
+        global int * src1,
+        ulong offset1,
+        global float * src2,
+        ulong offset2,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3,
+        int n_past,
+        int n_dims,
+        int n_ctx_orig,
+        float freq_base,
+        float freq_scale,
+        float ext_factor,
+        float attn_factor,
+        float beta_fast,
+        float beta_slow
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global int*)((global char*)src1 + offset1);
+    src2 = (global float*)((global char*)src2 + offset2);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i3 = get_group_id(2);
+    int i2 = get_group_id(1);
+    int i1 = get_group_id(0);
+
+    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
+
+    global int * pos = src1;
+
+    float theta_base = (float) pos[i2];
+    float inv_ndims = -1.f/n_dims;
+
+    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
+        if (i0 < n_dims) {
+            int ic = i0/2;
+
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
+
+            global half * src       = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+            global half * dst_data  = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[n_dims/2];
+
+            dst_data[0]        = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
+            dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
+        } else {
+            global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            global half * dst_data  = (global half *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
+//------------------------------------------------------------------------------
+// cpy
+//------------------------------------------------------------------------------
+
+kernel void kernel_cpy_f16_f16(
+        global half * src0,
+        ulong offset0,
+        global half * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = (global half*)((global char*)src0 + offset0);
+    dst = (global half*)((global char*)dst + offsetd);
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    int i3 = n / (ne2*ne1*ne0);
+    int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        global const half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
+kernel void kernel_cpy_f16_f32(
+        global half * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+
+    src0 = (global half*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    int i3 = n / (ne2*ne1*ne0);
+    int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        global half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
+kernel void kernel_cpy_f32_f16(
+        global float * src0,
+        ulong offset0,
+        global half * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global half*)((global char*)dst + offsetd);
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    int i3 = n / (ne2*ne1*ne0);
+    int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        dst_data[i00] = src[0];
+    }
+}
+
+kernel void kernel_cpy_f32_f32(
+        global float * src0,
+        ulong offset0,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3,
+        ulong nb0,
+        ulong nb1,
+        ulong nb2,
+        ulong nb3
+) {
+    src0 = (global float*)((global char*)src0 + offset0);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0);
+
+    int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    int i3 = n / (ne2*ne1*ne0);
+    int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
+        global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        dst_data[i00] = src[0];
+    }
+}
+
+//------------------------------------------------------------------------------
+// get_rows
+//------------------------------------------------------------------------------
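+// Each workgroup gathers one row: the row index r is read from src1 at (i10, i11),
+// and the ne00 elements of that src0 row (batch i02 = i11) are copied into the
+// corresponding dst row, striped across the local threads. The q4_0 variant
+// dequantizes the row 16 floats at a time.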
+kernel void kernel_get_rows_f32(
+        global void * src0,
+        ulong offset0,
+        global int * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        ulong nb01,
+        ulong nb02,
+        int ne10,
+        ulong nb10,
+        ulong nb11,
+        ulong nb1,
+        ulong nb2
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global int*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i10 = get_group_id(0);
+    int i11 = get_group_id(1);
+
+    int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    int i02 = i11;
+
+    for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
+        ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
+            ((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
+    }
+}
+
+kernel void kernel_get_rows_f16(
+        global void * src0,
+        ulong offset0,
+        global int * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        ulong nb01,
+        ulong nb02,
+        int ne10,
+        ulong nb10,
+        ulong nb11,
+        ulong nb1,
+        ulong nb2
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global int*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int i10 = get_group_id(0);
+    int i11 = get_group_id(1);
+
+    int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    int i02 = i11;
+
+    for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
+        ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
+            ((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
+    }
+}
+
+kernel void kernel_get_rows_q4_0(
+        global void * src0,
+        ulong offset0,
+        global int * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        ulong nb01,
+        ulong nb02,
+        int ne10,
+        ulong nb10,
+        ulong nb11,
+        ulong nb1,
+        ulong nb2
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global int*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    const int NL = 2;
+
+    int i10 = get_group_id(0);
+    int i11 = get_group_id(1);
+
+    int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
+
+    int i02 = i11;
+
+    for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) {
+        float16 temp;
+        dequantize_q4_0_f32(
+            ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp);
+        *(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+    }
+}
+
+//------------------------------------------------------------------------------
+// mul_mat_f32_f32
+//------------------------------------------------------------------------------
+#define N_F32_F32 4
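+// Each workgroup handles one src0 row (r0) and up to N_F32_F32 consecutive src1
+// rows; the dot product is striped across the subgroup lanes and combined with
+// sub_group_reduce_add. Rows with ne00 >= 128 take the vectorized float4 path,
+// with lane 0 adding the scalar tail.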
+
+kernel void kernel_mul_mat_f32_f32(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int r0 = get_group_id(0);
+    int rb = get_group_id(1)*N_F32_F32;
+    int im = get_group_id(2);
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+
+    global float * x = (global float *) (src0 + offset_src0);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F32_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+            global float * y = (global float *) (src1 + offset_src1);
+
+            float sumf = 0;
+            for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {
+                sumf += (float) x[i] * (float) y[i];
+            }
+
+            float all_sum = sub_group_reduce_add(sumf);
+            if (get_sub_group_local_id() == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        global float4 * x4 = (global float4 *)x;
+        for (int row = 0; row < N_F32_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+            global float  * y  = (global float  *) (src1 + offset_src1);
+            global float4 * y4 = (global float4 *) y;
+
+            float sumf = 0;
+            for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
+                sumf += (float) x4[i].s0 * y4[i].s0;
+                sumf += (float) x4[i].s1 * y4[i].s1;
+                sumf += (float) x4[i].s2 * y4[i].s2;
+                sumf += (float) x4[i].s3 * y4[i].s3;
+            }
+
+            float all_sum = sub_group_reduce_add(sumf);
+            if (get_sub_group_local_id() == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) {
+                    all_sum += (float) x[i] * y[i];
+                }
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
+//------------------------------------------------------------------------------
+// mul_mat_f16_f16
+//------------------------------------------------------------------------------
+#define N_F16_F16 4
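+// Same work decomposition as kernel_mul_mat_f32_f32, but both operands are half;
+// the per-element products are formed in half precision and accumulated into a
+// float sum.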
+
+kernel void kernel_mul_mat_f16_f16(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3)
+{
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int r0 = get_group_id(0);
+    int rb = get_group_id(1)*N_F16_F16;
+    int im = get_group_id(2);
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+
+    global half * x = (global half *) (src0 + offset_src0);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+            global half * y = (global half *) (src1 + offset_src1);
+
+            float sumf = 0;
+            for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {
+                sumf += (half) x[i] * (half) y[i];
+            }
+
+            float all_sum = sub_group_reduce_add(sumf);
+            if (get_sub_group_local_id() == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        global half4 * x4 = (global half4 *)x;
+        for (int row = 0; row < N_F16_F16; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+            global half  * y  = (global half  *) (src1 + offset_src1);
+            global half4 * y4 = (global half4 *) y;
+
+            float sumf = 0;
+            for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
+                sumf += (half) x4[i].s0 * y4[i].s0;
+                sumf += (half) x4[i].s1 * y4[i].s1;
+                sumf += (half) x4[i].s2 * y4[i].s2;
+                sumf += (half) x4[i].s3 * y4[i].s3;
+            }
+
+            float all_sum = sub_group_reduce_add(sumf);
+            if (get_sub_group_local_id() == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) {
+                    all_sum += (half) x[i] * y[i];
+                }
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
+//------------------------------------------------------------------------------
+// mul_mat_f16_f32_1row
+//------------------------------------------------------------------------------
+kernel void kernel_mul_mat_f16_f32_1row(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+    global half  * x = (global half  *) (src0 + offset_src0);
+    global float * y = (global float *) (src1 + offset_src1);
+
+    float sumf = 0;
+    if (ne00 < 128) {
+        for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {
+            sumf += (float) x[i] * (float) y[i];
+        }
+        float all_sum = sub_group_reduce_add(sumf);
+        if (get_sub_group_local_id() == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    } else {
+        global half4  * x4 = (global half4  *) x;
+        global float4 * y4 = (global float4 *) y;
+        for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
+            sumf += (float) x4[i].s0 * y4[i].s0;
+            sumf += (float) x4[i].s1 * y4[i].s1;
+            sumf += (float) x4[i].s2 * y4[i].s2;
+            sumf += (float) x4[i].s3 * y4[i].s3;
+        }
+        float all_sum = sub_group_reduce_add(sumf);
+        if (get_sub_group_local_id() == 0) {
+            for (int i = 4*(ne00/4); i < ne00; ++i) {
+                all_sum += (float) x[i] * y[i];
+            }
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    }
+
+}
+
+//------------------------------------------------------------------------------
+// mul_mat_f16_f32
+//------------------------------------------------------------------------------
+#define N_F16_F32 4
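+// Half weights with float activations: x values are widened with convert_float
+// before the multiply; otherwise the structure mirrors kernel_mul_mat_f32_f32.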
+
+#ifdef ADRENO_GPU
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_f16_f32(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int r0 = get_group_id(0);
+    int rb = get_group_id(1)*N_F16_F32;
+    int im = get_group_id(2);
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+
+    global half * x = (global half *) (src0 + offset_src0);
+
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+            global float * y = (global float *) (src1 + offset_src1);
+
+            float sumf = 0;
+            for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) {
+                sumf += convert_float(x[i]) * y[i];
+            }
+
+            float all_sum = sub_group_reduce_add(sumf);
+            if (get_sub_group_local_id() == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    } else {
+        global half4 * x4 = (global half4 *)x;
+        for (int row = 0; row < N_F16_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
+
+            ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+            global float  * y  = (global float  *) (src1 + offset_src1);
+            global float4 * y4 = (global float4 *) y;
+
+            float sumf = 0;
+            for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
+                sumf += convert_float(x4[i].s0) * y4[i].s0;
+                sumf += convert_float(x4[i].s1) * y4[i].s1;
+                sumf += convert_float(x4[i].s2) * y4[i].s2;
+                sumf += convert_float(x4[i].s3) * y4[i].s3;
+            }
+
+            float all_sum = sub_group_reduce_add(sumf);
+            if (get_sub_group_local_id() == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) {
+                    all_sum += (float) x[i] * y[i];
+                }
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
+        }
+    }
+}
+
+//------------------------------------------------------------------------------
+// mul_mat_f16_f32_l4
+//------------------------------------------------------------------------------
+// Assumes row size (ne00) is a multiple of 4
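+// This lets the kernel use the half4/float4 path exclusively, with no scalar
+// tail; each workgroup handles one src0 row and loops over all ne11 src1 rows.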
+#ifdef ADRENO_GPU
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_f16_f32_l4(
+        global char * src0,
+        ulong offset0,
+        global char * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        ulong nb00,
+        ulong nb01,
+        ulong nb02,
+        ulong nb03,
+        int ne10,
+        int ne11,
+        int ne12,
+        ulong nb10,
+        ulong nb11,
+        ulong nb12,
+        ulong nb13,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src0 = (global char*)((global char*)src0 + offset0);
+    src1 = (global char*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    int nrows = ne11;
+    int r0 = get_group_id(0);
+    int im = get_group_id(2);
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+
+    global half4 * x4 = (global half4 *) (src0 + offset_src0);
+
+    for (int r1 = 0; r1 < nrows; ++r1) {
+        ulong offset_src1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+        global float4 * y4 = (global float4 *) (src1 + offset_src1);
+
+        float sumf = 0;
+        for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) {
+            sumf += convert_float(x4[i].s0) * y4[i].s0;
+            sumf += convert_float(x4[i].s1) * y4[i].s1;
+            sumf += convert_float(x4[i].s2) * y4[i].s2;
+            sumf += convert_float(x4[i].s3) * y4[i].s3;
+        }
+
+        float all_sum = sub_group_reduce_add(sumf);
+        if (get_sub_group_local_id() == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    }
+}
+
+//------------------------------------------------------------------------------
+// mul_vec_q_n_f32
+//------------------------------------------------------------------------------
+// function to calculate the inner product between half a q4_0 block and 16 floats (yl); sumy is SUM(yl[i])
+// il indicates where the q4 quants begin (0 or QK4_0/4)
+// we assume that the yl values have already been multiplied by the appropriate scale factors
+// that correspond to the missing bit shifts (1, 1/16, 1/256, 1/4096)
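+// For example, (qs[i/2] & 0x0F00) extracts a nibble that is still scaled by 256,
+// so the matching y value is pre-divided by 256 by the caller instead of shifting
+// here; likewise 1/16 and 1/4096 for the 0x00F0 and 0xF000 nibbles. The q4_0
+// offset of -8 per weight is applied once at the end via sumy * -8.f, and the
+// block scale d multiplies the whole sum.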
+inline float block_q_4_0_dot_y(
+        global struct block_q4_0 * qb_curr,
+        float sumy,
+        private float * yl,
+        int il
+) {
+    float d = qb_curr->d;
+    float2 acc = 0.f;
+    global ushort * qs = ((global ushort *)qb_curr + 1 + il/2);
+    for (int i = 0; i < 8; i+=2) {
+        acc.s0 += yl[i + 0] * (qs[i / 2] & 0x000F)
+                + yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc.s1 += yl[i + 8] * (qs[i / 2] & 0x00F0)
+                + yl[i + 9] * (qs[i / 2] & 0xF000);
+    }
+    return d * (sumy * -8.f + acc.s0 + acc.s1);
+}
+
+#ifdef INTEL_GPU
+#define N_DST 4 // each SIMD group works on 4 rows
+#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
+#elif defined (ADRENO_GPU)
+#define N_DST 4
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+
+inline void mul_vec_q_n_f32(
+        global void * src0,
+        global float * src1,
+        global float * dst,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+
+    const ulong nb = ne00/QK4_0;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    // (r0 * N_SIMDGROUP + get_sub_group_id()) is essentially the linear global
+    // id of a SIMD group in the grid.
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0;
+    global float             * y = (global float             *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float yl[16];       // src1 vector cache
+    float sumf[N_DST]={0.f};
+
+    int ix = get_sub_group_local_id()/2;
+    int il = 8*(get_sub_group_local_id()%2);
+
+    global float * yb = y + ix * QK4_0 + il;
+
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
+        float sumy = 0;
+        for (int i = 0; i < 8; i += 2) {
+            sumy += yb[i] + yb[i+1];
+            yl[i+0] = yb[i+ 0];
+            yl[i+1] = yb[i+ 1]/256.f;
+            sumy += yb[i+16] + yb[i+17];
+            yl[i+8] = yb[i+16]/16.f;
+            yl[i+9] = yb[i+17]/4096.f;
+        }
+
+        for (int row = 0; row < N_DST; row++) {
+            sumf[row] += block_q_4_0_dot_y(x+ib+row*nb, sumy, yl, il);
+        }
+
+        // One thread in a SIMD group (i.e., subgroup) handles half a block,
+        // hence the entire SIMD group handles SIMDWIDTH/2 blocks.
+        // y points to the activation matrix (of type float). Therefore, per
+        // iteration, y should advance by SIMDWIDTH/2 blocks (the number of
+        // blocks processed by a SIMD group) - in terms of floats, that is
+        // QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size.
+        yb += QK4_0 * (N_SIMDWIDTH/2);
+    }
+
+    // The above does not work for Adreno - it produces incorrect results for
+    // row = 1, 2, 3 and only row = 0 gives the correct result.
+    // If N_DST is changed, the below array must be initialized accordingly.
+    // This also seems to perform better on Intel.
+    float tot[N_DST] = {
+        sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]),
+        sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])};
+    for (int row = 0; row < N_DST; ++row) {
+        if (get_sub_group_local_id() == 0 && first_row + row < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row];
+        }
+    }
+}
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_q4_0_f32(
+        global void * src0,
+        ulong offset0,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
+}
+
+//
+// This variant unrolls the loops and uses vector types instead of pointers.
+// It improves performance on Adreno but not so much on Intel.
+//
+inline float block_q_4_0_dot_y_v(
+        global struct block_q4_0 * qb_curr,
+        float sumy,
+        float16 yl,
+        int il
+) {
+    float d = qb_curr->d;
+    float acc = 0.f;
+    global ushort * qs = ((global ushort *)qb_curr + 1 + il/2);
+
+    acc += yl.s0 * (qs[0] & 0x000F);
+    acc += yl.s1 * (qs[0] & 0x0F00);
+    acc += yl.s8 * (qs[0] & 0x00F0);
+    acc += yl.s9 * (qs[0] & 0xF000);
+
+    acc += yl.s2 * (qs[1] & 0x000F);
+    acc += yl.s3 * (qs[1] & 0x0F00);
+    acc += yl.sa * (qs[1] & 0x00F0);
+    acc += yl.sb * (qs[1] & 0xF000);
+
+    acc += yl.s4 * (qs[2] & 0x000F);
+    acc += yl.s5 * (qs[2] & 0x0F00);
+    acc += yl.sc * (qs[2] & 0x00F0);
+    acc += yl.sd * (qs[2] & 0xF000);
+
+    acc += yl.s6 * (qs[3] & 0x000F);
+    acc += yl.s7 * (qs[3] & 0x0F00);
+    acc += yl.se * (qs[3] & 0x00F0);
+    acc += yl.sf * (qs[3] & 0xF000);
+
+    return d * (sumy * -8.f + acc);
+}
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 4 // each SIMD group works on 4 rows
+#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
+#elif defined (ADRENO_GPU)
+#define N_DST 4
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+
+inline void mul_vec_q_n_f32_v(
+        global void * src0,
+        global float * src1,
+        global float * dst,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    const ulong nb = ne00/QK4_0;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    // (r0 * N_SIMDGROUP + get_sub_group_id()) is essentially the linear global
+    // id of a SIMD group in the grid.
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0;
+    global float             * y = (global float             *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float16 yl;       // src1 vector cache
+    float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
+
+    int ix = get_sub_group_local_id()/2;
+    int il = 8*(get_sub_group_local_id()%2);
+
+    global float * yb = y + ix * QK4_0 + il;
+
+    // each thread in a SIMD group deals with half a block.
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
+        float sumy = 0;
+
+        sumy += yb[0];
+        sumy += yb[1];
+        sumy += yb[2];
+        sumy += yb[3];
+        sumy += yb[4];
+        sumy += yb[5];
+        sumy += yb[6];
+        sumy += yb[7];
+
+        sumy += yb[16];
+        sumy += yb[17];
+        sumy += yb[18];
+        sumy += yb[19];
+        sumy += yb[20];
+        sumy += yb[21];
+        sumy += yb[22];
+        sumy += yb[23];
+
+
+        yl.s0 = yb[0];
+        yl.s1 = yb[1]/256.f;
+
+        yl.s2 = yb[2];
+        yl.s3 = yb[3]/256.f;
+
+        yl.s4 = yb[4];
+        yl.s5 = yb[5]/256.f;
+
+        yl.s6 = yb[6];
+        yl.s7 = yb[7]/256.f;
+
+        yl.s8 = yb[16]/16.f;
+        yl.s9 = yb[17]/4096.f;
+
+        yl.sa = yb[18]/16.f;
+        yl.sb = yb[19]/4096.f;
+
+        yl.sc = yb[20]/16.f;
+        yl.sd = yb[21]/4096.f;
+
+        yl.se = yb[22]/16.f;
+        yl.sf = yb[23]/4096.f;
+
+        sumf.s0 += block_q_4_0_dot_y_v(x+ib+0*nb, sumy, yl, il);
+        sumf.s1 += block_q_4_0_dot_y_v(x+ib+1*nb, sumy, yl, il);
+        sumf.s2 += block_q_4_0_dot_y_v(x+ib+2*nb, sumy, yl, il);
+        sumf.s3 += block_q_4_0_dot_y_v(x+ib+3*nb, sumy, yl, il);
+
+        // One thread in a SIMD group (i.e., subgroup) handles half a block,
+        // hence the entire SIMD group handles SIMDWIDTH/2 blocks.
+        // y points to the activation matrix (of type float). Therefore, per
+        // iteration, y should advance by SIMDWIDTH/2 blocks (the number of
+        // blocks processed by a SIMD group) - in terms of floats, that is
+        // QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size.
+        yb += QK4_0 * (N_SIMDWIDTH/2);
+    }
+
+    // The above does not work for Adreno - it produces incorrect results for
+    // row = 1, 2, 3 and only row = 0 gives the correct result.
+    // If N_DST is changed, the below array must be initialized accordingly.
+    // This also seems to perform better on Intel.
+    float4 tot = (float4)(
+        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
+        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
+    );
+
+    if (get_sub_group_local_id() == 0) {
+        if (first_row + 0 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
+        }
+        if (first_row + 1 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
+        }
+        if (first_row + 2 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
+        }
+        if (first_row + 3 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
+        }
+    }
+}
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_q4_0_f32_v(
+        global void * src0,
+        ulong offset0,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    mul_vec_q_n_f32_v(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
+}
+
+//------------------------------------------------------------------------------
+// kernel_convert_block_q4_0
+// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA).
+// This kernel does not deshuffle the bits.
+//------------------------------------------------------------------------------
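+// Resulting layout: dst_q stores the QK4_0/2 quant bytes of block i contiguously
+// at dst_q + i*QK4_0/2, and dst_d stores the per-block scale at dst_d + i.
+// kernel_restore_block_q4_0 below performs the inverse packing.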
+kernel void kernel_convert_block_q4_0(
+    global struct block_q4_0 * src0,
+    global uchar * dst_q,
+    global half  * dst_d
+) {
+    global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0);
+    global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0);
+    global half  * d = (global half *) dst_d + get_global_id(0);
+
+    *d = b->d;
+
+    for (int i = 0; i < QK4_0/2; ++i) {
+        q[i] = b->qs[i];
+    }
+}
+
+kernel void kernel_restore_block_q4_0(
+    global uchar * src_q,
+    global half  * src_d,
+    global struct block_q4_0 * dst
+) {
+    global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0);
+    global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0);
+    global half  * d = (global half *) src_d + get_global_id(0);
+
+    b->d = *d;
+    for (int i = 0; i < QK4_0/2; ++i) {
+        b->qs[i] = q[i];
+    }
+}
+
+//------------------------------------------------------------------------------
+// mul_vec_q_n_f32_flat
+//
+// This variation uses flat arrays (struct of arrays, SOA) representation for
+// quant tensors.
+//------------------------------------------------------------------------------
+
+// This function requires the original shuffled weights.
+// As a reminder, the original weights are shuffled so that (q[0], q[16]) are
+// packed together in a byte, so are (q[1], q[17]) and so on.
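+// With the flat layout, block ib of row `row` is read from
+// x + (ib + row*nb)*QK4_0/2 and its scale from d + ib + row*nb
+// (see the sumf.s* updates in mul_vec_q_n_f32_flat below).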
+inline float block_q_4_0_dot_y_flat(
+        global uchar * x,
+        global half  * dh,
+        float sumy,
+        float16 yl,
+        int il
+) {
+    float           d   = *dh;
+    global ushort * qs  = ((global ushort *)x + il/2);
+    float           acc = 0.f;
+
+    acc += yl.s0 * (qs[0] & 0x000F);
+    acc += yl.s1 * (qs[0] & 0x0F00);
+    acc += yl.s8 * (qs[0] & 0x00F0);
+    acc += yl.s9 * (qs[0] & 0xF000);
+
+    acc += yl.s2 * (qs[1] & 0x000F);
+    acc += yl.s3 * (qs[1] & 0x0F00);
+    acc += yl.sa * (qs[1] & 0x00F0);
+    acc += yl.sb * (qs[1] & 0xF000);
+
+    acc += yl.s4 * (qs[2] & 0x000F);
+    acc += yl.s5 * (qs[2] & 0x0F00);
+    acc += yl.sc * (qs[2] & 0x00F0);
+    acc += yl.sd * (qs[2] & 0xF000);
+
+    acc += yl.s6 * (qs[3] & 0x000F);
+    acc += yl.s7 * (qs[3] & 0x0F00);
+    acc += yl.se * (qs[3] & 0x00F0);
+    acc += yl.sf * (qs[3] & 0xF000);
+
+    return d * (sumy * -8.f + acc);
+}
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 4 // each SIMD group works on 4 rows
+#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
+#elif defined (ADRENO_GPU)
+#define N_DST 4
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+
+inline void mul_vec_q_n_f32_flat(
+        global uchar * src0_q,
+        global half  * src0_d,
+        global float * src1,
+        global float * dst,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    const ulong nb = ne00/QK4_0;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
+    // a SIMD group in the grid. Each SIMD group produces N_DST values in the
+    // result; each row spans nb blocks, so the row offset becomes first_row*nb.
+    // Currently with llama2 7B, im is always 0.
+    // TODO: how to handle im/gqa*(nb*ne0)?
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    // The number of scales is the same as the number of blocks.
+    ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
+    ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
+
+    global uchar * x = (global uchar *) src0_q + offset0_q;
+    global half  * d = (global half  *) src0_d + offset0_d;
+    global float * y = (global float *) src1   + r1*ne10 + im*ne00*ne1;
+
+    float16 yl;
+    float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
+
+    int ix = get_sub_group_local_id()/2;
+    int il = 8*(get_sub_group_local_id()%2);
+
+    global float * yb = y + ix*QK4_0 + il;
+
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
+        float sumy = 0.f;
+
+        sumy += yb[0];
+        sumy += yb[1];
+        sumy += yb[2];
+        sumy += yb[3];
+        sumy += yb[4];
+        sumy += yb[5];
+        sumy += yb[6];
+        sumy += yb[7];
+
+        sumy += yb[16];
+        sumy += yb[17];
+        sumy += yb[18];
+        sumy += yb[19];
+        sumy += yb[20];
+        sumy += yb[21];
+        sumy += yb[22];
+        sumy += yb[23];
+
+        yl.s0 = yb[0];
+        yl.s1 = yb[1]/256.f;
+
+        yl.s2 = yb[2];
+        yl.s3 = yb[3]/256.f;
+
+        yl.s4 = yb[4];
+        yl.s5 = yb[5]/256.f;
+
+        yl.s6 = yb[6];
+        yl.s7 = yb[7]/256.f;
+
+        yl.s8 = yb[16]/16.f;
+        yl.s9 = yb[17]/4096.f;
+
+        yl.sa = yb[18]/16.f;
+        yl.sb = yb[19]/4096.f;
+
+        yl.sc = yb[20]/16.f;
+        yl.sd = yb[21]/4096.f;
+
+        yl.se = yb[22]/16.f;
+        yl.sf = yb[23]/4096.f;
+
+        sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);
+        sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);
+        sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);
+        sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);
+
+        yb += QK4_0 * (N_SIMDWIDTH/2);
+    }
+
+    float4 tot = (float4)(
+        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
+        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
+    );
+
+    if (get_sub_group_local_id() == 0) {
+        if (first_row + 0 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
+        }
+        if (first_row + 1 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
+        }
+        if (first_row + 2 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
+        }
+        if (first_row + 3 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
+        }
+    }
+}
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_q4_0_f32_flat(
+        global uchar * src0_q,
+        global half  * src0_d,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    mul_vec_q_n_f32_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
+}
+
+//
+// This variant outputs 8 values.
+//
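+// Each subgroup produces N_DST = 8 consecutive output rows, reusing the same yl
+// and sumy values for all eight rows in every iteration.
+//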
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 8 // each SIMD group works on 8 rows
+#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
+#elif defined (ADRENO_GPU)
+#define N_DST 8
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+
+inline void mul_vec_q_n_f32_8x_flat(
+        global uchar * src0_q,
+        global half  * src0_d,
+        global float * src1,
+        global float * dst,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    const ulong nb = ne00/QK4_0;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
+    // a SIMD group in the grid. Each SIMD group produces N_DST values in the
+    // result; each row spans nb blocks, so the row offset becomes first_row*nb.
+    // Currently with llama2 7B, im is always 0.
+    // TODO: how to handle im/gqa*(nb*ne0)?
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    // The number of scales is the same as the number of blocks.
+    ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
+    ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
+
+    global uchar * x = (global uchar *) src0_q + offset0_q;
+    global half  * d = (global half  *) src0_d + offset0_d;
+    global float * y = (global float *) src1   + r1*ne10 + im*ne00*ne1;
+
+    float16 yl;
+    float8 sumf = 0.f;
+
+    int ix = get_sub_group_local_id()/2;
+    int il = 8*(get_sub_group_local_id()%2);
+
+    global float * yb = y + ix*QK4_0 + il;
+
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
+        float sumy = 0.f;
+
+        sumy += yb[0];
+        sumy += yb[1];
+        sumy += yb[2];
+        sumy += yb[3];
+        sumy += yb[4];
+        sumy += yb[5];
+        sumy += yb[6];
+        sumy += yb[7];
+
+        sumy += yb[16];
+        sumy += yb[17];
+        sumy += yb[18];
+        sumy += yb[19];
+        sumy += yb[20];
+        sumy += yb[21];
+        sumy += yb[22];
+        sumy += yb[23];
+
+        yl.s0 = yb[0];
+        yl.s1 = yb[1]/256.f;
+
+        yl.s2 = yb[2];
+        yl.s3 = yb[3]/256.f;
+
+        yl.s4 = yb[4];
+        yl.s5 = yb[5]/256.f;
+
+        yl.s6 = yb[6];
+        yl.s7 = yb[7]/256.f;
+
+        yl.s8 = yb[16]/16.f;
+        yl.s9 = yb[17]/4096.f;
+
+        yl.sa = yb[18]/16.f;
+        yl.sb = yb[19]/4096.f;
+
+        yl.sc = yb[20]/16.f;
+        yl.sd = yb[21]/4096.f;
+
+        yl.se = yb[22]/16.f;
+        yl.sf = yb[23]/4096.f;
+
+        sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);
+        sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);
+        sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);
+        sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);
+
+        sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il);
+        sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il);
+        sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il);
+        sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il);
+
+        yb += QK4_0 * (N_SIMDWIDTH/2);
+    }
+
+    float8 tot = (float8)(
+        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
+        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),
+        sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),
+        sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)
+    );
+
+    if (get_sub_group_local_id() == 0) {
+        if (first_row + 0 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
+        }
+        if (first_row + 1 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
+        }
+        if (first_row + 2 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
+        }
+        if (first_row + 3 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
+        }
+
+        if (first_row + 4 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
+        }
+        if (first_row + 5 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
+        }
+        if (first_row + 6 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
+        }
+        if (first_row + 7 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
+        }
+    }
+}
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_q4_0_f32_8x_flat(
+        global uchar * src0_q,
+        global half  * src0_d,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
+}
diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl
new file mode 100644
index 00000000000..e2024332f81
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl
@@ -0,0 +1,106 @@
+//------------------------------------------------------------------------------
+// This file contains additional kernels for data conversion.
+// These kernels are used when loading the model, so their performance is less
+// important.
+//------------------------------------------------------------------------------
+#ifdef cl_khr_fp16
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#elif defined(cl_amd_fp16)
+#pragma OPENCL EXTENSION cl_amd_fp16 : enable
+#else
+#error "Half precision floating point not supportedby OpenCL implementation on your device."
+#endif
+
+#ifdef cl_khr_subgroups
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#elif defined(cl_intel_subgroups)
+#pragma OPENCL EXTENSION cl_intel_subgroups : enable
+#else
+#error "Subgroup not supported on your device."
+#endif
+
+#ifdef cl_intel_required_subgroup_size
+// Always use subgroup size of 32 on Intel.
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+// Always use a subgroup size of 64 on Adreno.
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#else
+// TODO: do not know how to choose subgroup size on other GPUs.
+#error "Selecting subgroup size is not supported on your device."
+#endif
+
+#define QK4_0                   32
+#define QR4_0                   2
+#define QK4_1                   32
+#define QR4_1                   2
+#define QK5_0                   32
+#define QR5_0                   2
+#define QK5_1                   32
+#define QR5_1                   2
+#define QK8_0                   32
+#define QR8_0                   1
+#define QK_K                    256
+#define K_QUANTS_PER_ITERATION  2
+
+typedef char int8_t;
+typedef uchar uint8_t;
+typedef short int16_t;
+typedef ushort uint16_t;
+typedef int int32_t;
+typedef uint uint32_t;
+
+//------------------------------------------------------------------------------
+// block_q4_0
+//------------------------------------------------------------------------------
+struct block_q4_0
+{
+    half d;
+    uint8_t qs[QK4_0 / 2];
+};
+
+//------------------------------------------------------------------------------
+// mul_vec_q_n_f32_flat_noshuffle
+//
+// This variation uses flat arrays (struct of arrays, SOA) representation for
+// quant tensors. It also uses a non-shuffled bit order for the weights.
+//
+// The shuffled version is kept in the original file because moving it here
+// seems to result in worse performance for Adreno.
+//------------------------------------------------------------------------------
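+// In the shuffled q4_0 layout, byte j of a block packs weight j in its low nibble
+// and weight j+16 in its high nibble. kernel_convert_block_q4_0_noshuffle regroups
+// them: the first QK4_0/4 output bytes hold weights 0..15 in order (two per byte)
+// and the next QK4_0/4 bytes hold weights 16..31.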
+
+kernel void kernel_convert_block_q4_0_noshuffle(
+    global struct block_q4_0 * src0,
+    global uchar * dst_q,
+    global half  * dst_d
+) {
+    global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0);
+    global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0);
+    global half  * d = (global half *) dst_d + get_global_id(0);
+
+    *d = b->d;
+    for (int i = 0; i < QK4_0/4; ++i) {
+        uchar x0 = b->qs[2*i + 0];
+        uchar x1 = b->qs[2*i + 1];
+
+        q[i + 0      ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4);
+        q[i + QK4_0/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0);
+
+#ifdef ADRENO_GPU
+        // Workaround for Adreno - the following printf statement must be present
+        // for the kernel to work properly; otherwise it produces incorrect results.
+        // The convert_uchar calls above also seem necessary.
+        // Compare against a large number so that it never actually prints anything.
+        // Using get_sub_group_local_id() in the comparison also works.
+        if (get_global_id(0) == 65536*4096) {
+            printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0));
+        }
+#endif
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl
new file mode 100644
index 00000000000..5e195411d69
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl
@@ -0,0 +1,265 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
+#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
+#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+
+// assumed constants for the kernels in this file
+#define QK4_0 32
+#define N_SIMDGROUP 4
+
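+// The macros below accumulate the dot-product contribution of one ushort8 of
+// packed 4-bit weights (bits4) for two output rows at once: the even components
+// of bits4 pair with scale.s0/total_sums.s0 and the odd components with
+// scale.s1/total_sums.s1. Activation values are shared across the subgroup via
+// sub_group_broadcast; the implicit q4_0 offset is applied per nibble through
+// the "- 8" terms.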
+#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \
+    float shared_y; \
+    shared_y = sub_group_broadcast(y.s0, 0); \
+    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 0); \
+    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 0); \
+    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 0); \
+    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s4, 0); \
+    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 0); \
+    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 0); \
+    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 0); \
+    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s0, 1); \
+    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 1); \
+    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 1); \
+    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 1); \
+    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s4, 1); \
+    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 1); \
+    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 1); \
+    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 1); \
+    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+
+
+#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \
+    shared_y = sub_group_broadcast(y.s0, 2); \
+    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 2); \
+    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 2); \
+    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 2); \
+    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s4, 2); \
+    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 2); \
+    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 2); \
+    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 2); \
+    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s0, 3); \
+    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 3); \
+    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 3); \
+    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 3); \
+    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s4, 3); \
+    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 3); \
+    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 3); \
+    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 3); \
+    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+
+
+#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \
+    float8 shared_y; \
+    shared_y = sub_group_broadcast(y, 0); \
+    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
+    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
+    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
+    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
+    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
+    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
+    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
+    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
+    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
+    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
+    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
+    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
+    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
+    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
+    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
+    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
+    shared_y = sub_group_broadcast(y, 1); \
+    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
+    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
+    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
+    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
+    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
+    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
+    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
+    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
+    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
+    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
+    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
+    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
+    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
+    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
+    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
+    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
+
+
+#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \
+    shared_y = sub_group_broadcast(y, 2); \
+    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
+    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
+    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
+    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
+    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
+    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
+    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
+    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
+    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
+    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
+    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
+    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
+    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
+    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
+    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
+    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
+    shared_y = sub_group_broadcast(y, 3); \
+    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
+    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
+    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
+    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
+    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
+    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
+    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
+    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
+    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
+    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
+    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
+    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
+    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
+    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
+    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
+    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
+
+
+__attribute__((qcom_reqd_sub_group_size("full")))
+__kernel void kernel_gemv_noshuffle(
+        __read_only  image1d_buffer_t src0_q,  // quantized A
+        global half2  * src0_d,  // A scales
+        __read_only  image1d_buffer_t src1,    // B
+        ulong offset1,            // offset to B (0)
+        global float * dst,     // C
+        ulong offsetd,            // offset to C (0)
+        uint K,               // K
+        int ne01,               // M
+        int ne02,               // 1
+        int ne10,               // K
+        int ne12,               // 1
+        int ne0,                // M
+        int ne1,                // N
+        int r2,                 // 1
+        int r3)
+{
+    uint groupId = get_local_id(1);
+    uint gid     = get_global_id(0);
+    ushort slid    = get_sub_group_local_id();
+
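+    // Note: LINE_STRIDE_A, BLOCK_STRIDE_A and SIMDGROUP_WIDTH are not defined in this
+    // file; they are presumably supplied as compile-time defines by the host when the
+    // kernel is built. The _general variant of this kernel computes the two strides
+    // from ne01 at runtime instead.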
+    __private uint4     regA;
+    __private half2     regS;
+    __private float8    regB;
+
+    __private float2 totalSum = (float2)(0.0f);
+
+    // loop along K in block granularity, skip 4 blocks every iter
+    for (uint k = groupId; k < (K / QK4_0); k += N_SIMDGROUP) {
+        regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows
+        // the first 4 fibers in each wave load 8 B values into their private scope
+        if (slid < 4) {
+            regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));
+            regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));
+        }
+
+        // load half weights for two blocks in consecutive rows
+        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
+        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
+        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
+        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
+#ifdef VECTOR_SUB_GROUP_BROADCAT
+        dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB);
+#else
+        dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB);
+#endif // VECTOR_SUB_GROUP_BROADCAT
+
+        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
+        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
+        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
+        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
+#ifdef VECTOR_SUB_GROUP_BROADCAT
+        dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB);
+#else
+        dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB);
+#endif // VECTOR_SUB_GROUP_BROADCAT
+    }
+
+    // reduction in local memory, assumes #wave=4
+    __local float2 reduceLM[SIMDGROUP_WIDTH * 3];
+    if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum;
+    if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum;
+    if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
+    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
+    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
+
+    // 2 outputs per fiber in wave 0
+    if (groupId == 0) {
+        dst = (global float*)((global char*)dst + offsetd);
+        vstore2(totalSum, 0, &(dst[gid * 2]));
+    }
+
+}
diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl
new file mode 100644
index 00000000000..5bdd4d06763
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl
@@ -0,0 +1,271 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable
+#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable
+#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+
+// assumed quantization block size and number of SIMD groups per workgroup
+#define QK4_0 32
+#define N_SIMDGROUP 4
+
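+// The dequantizeBlockAccum_ns_sgbroadcast_* macros below accumulate a partial Q4_0
+// dot product for two rows of A at once (total_sums.s0/.s1). Each ushort of bits4
+// packs four 4-bit quants; every quant q contributes (q - 8) * scale * y, where the
+// activation y is fetched from another fiber via sub_group_broadcast. The _1_*
+// variants broadcast one float at a time, the _8_* variants a float8; _hi and _lo
+// cover the first and second half of the block, respectively.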
+#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \
+    float shared_y; \
+    shared_y = sub_group_broadcast(y.s0, 0); \
+    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 0); \
+    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 0); \
+    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 0); \
+    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s4, 0); \
+    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 0); \
+    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 0); \
+    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 0); \
+    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s0, 1); \
+    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 1); \
+    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 1); \
+    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 1); \
+    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s4, 1); \
+    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 1); \
+    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 1); \
+    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 1); \
+    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+
+
+#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \
+    shared_y = sub_group_broadcast(y.s0, 2); \
+    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 2); \
+    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 2); \
+    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 2); \
+    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s4, 2); \
+    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 2); \
+    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 2); \
+    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 2); \
+    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s0, 3); \
+    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s1, 3); \
+    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s2, 3); \
+    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s3, 3); \
+    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s4, 3); \
+    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s5, 3); \
+    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s6, 3); \
+    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y; \
+    shared_y = sub_group_broadcast(y.s7, 3); \
+    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y; \
+    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y; \
+
+
+#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \
+    float8 shared_y; \
+    shared_y = sub_group_broadcast(y, 0); \
+    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
+    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
+    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
+    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
+    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
+    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
+    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
+    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
+    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
+    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
+    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
+    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
+    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
+    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
+    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
+    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
+    shared_y = sub_group_broadcast(y, 1); \
+    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
+    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
+    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
+    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
+    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
+    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
+    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
+    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
+    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
+    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
+    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
+    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
+    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
+    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
+    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
+    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
+
+
+#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \
+    shared_y = sub_group_broadcast(y, 2); \
+    total_sums.s0 += ((bits4.s0 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
+    total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
+    total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
+    total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
+    total_sums.s0 += ((bits4.s2 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
+    total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
+    total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
+    total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
+    total_sums.s1 += ((bits4.s1 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
+    total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
+    total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
+    total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
+    total_sums.s1 += ((bits4.s3 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
+    total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
+    total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
+    total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
+    shared_y = sub_group_broadcast(y, 3); \
+    total_sums.s0 += ((bits4.s4 & 0x000F) - 8) * scale.s0 * shared_y.s0; \
+    total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s1; \
+    total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s2; \
+    total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s3; \
+    total_sums.s0 += ((bits4.s6 & 0x000F) - 8) * scale.s0 * shared_y.s4; \
+    total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) - 8) * scale.s0 * shared_y.s5; \
+    total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) - 8) * scale.s0 * shared_y.s6; \
+    total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) - 8) * scale.s0 * shared_y.s7; \
+    total_sums.s1 += ((bits4.s5 & 0x000F) - 8) * scale.s1 * shared_y.s0; \
+    total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s1; \
+    total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s2; \
+    total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s3; \
+    total_sums.s1 += ((bits4.s7 & 0x000F) - 8) * scale.s1 * shared_y.s4; \
+    total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) - 8) * scale.s1 * shared_y.s5; \
+    total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) - 8) * scale.s1 * shared_y.s6; \
+    total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) - 8) * scale.s1 * shared_y.s7; \
+
+
+__attribute__((qcom_reqd_sub_group_size("full")))
+__kernel void kernel_gemv_noshuffle(
+        __read_only  image1d_buffer_t src0_q,  // quantized A
+        global half2  * src0_d,  // A scales
+        __read_only  image1d_buffer_t src1,    // B
+        ulong offset1,            // offset to B (0)
+        global float * dst,     // C
+        ulong offsetd,            // offset to C (0)
+        int ne00,               // K
+        int ne01,               // M
+        int ne02,               // 1
+        int ne10,               // K
+        int ne12,               // 1
+        int ne0,                // M
+        int ne1,                // N
+        int r2,                 // 1
+        int r3)
+{
+    uint groupId = get_local_id(1);
+    uint gid     = get_global_id(0);
+    ushort slid    = get_sub_group_local_id();
+
+    uint K = ne00;
+    uint M = ne01;
+
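+    // strides into the flattened A buffers, derived from M at runtime; the
+    // non-general kernel takes LINE_STRIDE_A / BLOCK_STRIDE_A as compile-time defines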
+    uint LINE_STRIDE_A = M / 2;
+    uint BLOCK_STRIDE_A = N_SIMDGROUP * M;
+
+    __private uint4     regA;
+    __private half2     regS;
+    __private float8    regB;
+
+    __private float2 totalSum = (float2)(0.0f);
+
+    // loop along K in block granularity, skip 4 blocks every iter
+    for (uint k = groupId; k < (K / QK4_0); k += N_SIMDGROUP) {
+        regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows
+        // the first 4 fibers in each wave load 8 B values into their private scope
+        if (slid < 4) {
+            regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));
+            regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));
+        }
+
+        // load half weights for two blocks in consecutive rows
+        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
+        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
+        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
+        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
+#ifdef VECTOR_SUB_GROUP_BROADCAT
+        dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB);
+#else
+        dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB);
+#endif // VECTOR_SUB_GROUP_BROADCAT
+
+        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
+        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
+        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
+        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;
+#ifdef VECTOR_SUB_GROUP_BROADCAT
+        dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB);
+#else
+        dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB);
+#endif // VECTOR_SUB_GROUP_BROADCAT
+    }
+
+    // reduction in local memory, assumes #wave=4
+    __local float2 reduceLM[SIMDGROUP_WIDTH * 3];
+    if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum;
+    if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum;
+    if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum;
+    barrier(CLK_LOCAL_MEM_FENCE);
+    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
+    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
+    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];
+
+    // 2 outputs per fiber in wave 0
+    if (groupId == 0) {
+        dst = (global float*)((global char*)dst + offsetd);
+        vstore2(totalSum, 0, &(dst[gid * 2]));
+    }
+
+}
diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl
new file mode 100644
index 00000000000..e19e9a2f436
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl
@@ -0,0 +1,1225 @@
+//------------------------------------------------------------------------------
+// This file contains additional mulmat kernels
+// (and potentially other kernels).
+//------------------------------------------------------------------------------
+#ifdef cl_khr_fp16
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#elif defined(cl_amd_fp16)
+#pragma OPENCL EXTENSION cl_amd_fp16 : enable
+#else
+#error "Half precision floating point not supportedby OpenCL implementation on your device."
+#endif
+
+#ifdef cl_khr_subgroups
+#pragma OPENCL EXTENSION cl_khr_subgroups : enable
+#elif defined(cl_intel_subgroups)
+#pragma OPENCL EXTENSION cl_intel_subgroups : enable
+#else
+#error "Subgroup not supported on your device."
+#endif
+
+#ifdef cl_intel_required_subgroup_size
+// Always use subgroup size of 32 on Intel.
+#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
+#define INTEL_GPU 1
+#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
+#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
+#elif defined(cl_qcom_reqd_sub_group_size)
+// Always use a subgroup size of 64 on Adreno.
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+#define ADRENO_GPU 1
+#define REQD_SUBGROUP_SIZE_64  __attribute__((qcom_reqd_sub_group_size("half")))
+#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
+#else
+// TODO: do not know how to choose subgroup size on other GPUs.
+#error "Selecting subgroup size is not supported on your device."
+#endif
+
+#define QK4_0                   32
+#define QR4_0                   2
+#define QK4_1                   32
+#define QR4_1                   2
+#define QK5_0                   32
+#define QR5_0                   2
+#define QK5_1                   32
+#define QR5_1                   2
+#define QK8_0                   32
+#define QR8_0                   1
+#define QK_K                    256
+#define K_QUANTS_PER_ITERATION  2
+
+typedef char int8_t;
+typedef uchar uint8_t;
+typedef short int16_t;
+typedef ushort uint16_t;
+typedef int int32_t;
+typedef uint uint32_t;
+
+//------------------------------------------------------------------------------
+// block_q4_0
+//------------------------------------------------------------------------------
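+// 4-bit quantization: each block packs QK4_0 = 32 weights as 4-bit quants (two per
+// byte in qs) sharing a single fp16 scale d, so a weight decodes as x = d * (q - 8).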
+struct block_q4_0
+{
+    half d;
+    uint8_t qs[QK4_0 / 2];
+};
+
+//------------------------------------------------------------------------------
+// block_q6_K
+//------------------------------------------------------------------------------
+// 6-bit quantization
+// weight is represented as x = a * q
+// 16 blocks of 16 elements each
+// Effectively 6.5625 bits per weight
+typedef struct {
+    uint8_t ql[QK_K/2];      // quants, lower 4 bits
+    uint8_t qh[QK_K/4];      // quants, upper 2 bits
+    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
+    half d;             // super-block scale
+} block_q6_K;
+
+//------------------------------------------------------------------------------
+// These are the variants for mat-mat multiplication, based on the mat-vec kernel with
+// flattened block_q4_0.
+//------------------------------------------------------------------------------
+
+// Common dot prod.
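+// The activations in yl are expected pre-scaled (see the /256, /16 and /4096 factors
+// where yl is filled in the callers below) so that each masked-but-unshifted nibble
+// of qs contributes q * y to acc; the -8 offset of Q4_0 is folded in once at the end
+// via sumy, giving d * sum((q - 8) * y) over the 16 weights this call covers.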
+inline float mm_block_q_4_0_dot_y_flat(
+        global uchar * x,
+        global half  * dh,
+        float sumy,
+        float16 yl,
+        int il
+) {
+    float           d   = *dh;
+    global ushort * qs  = ((global ushort *)x + il/2);
+    float           acc = 0.f;
+
+    acc += yl.s0 * (qs[0] & 0x000F);
+    acc += yl.s1 * (qs[0] & 0x0F00);
+    acc += yl.s8 * (qs[0] & 0x00F0);
+    acc += yl.s9 * (qs[0] & 0xF000);
+
+    acc += yl.s2 * (qs[1] & 0x000F);
+    acc += yl.s3 * (qs[1] & 0x0F00);
+    acc += yl.sa * (qs[1] & 0x00F0);
+    acc += yl.sb * (qs[1] & 0xF000);
+
+    acc += yl.s4 * (qs[2] & 0x000F);
+    acc += yl.s5 * (qs[2] & 0x0F00);
+    acc += yl.sc * (qs[2] & 0x00F0);
+    acc += yl.sd * (qs[2] & 0xF000);
+
+    acc += yl.s6 * (qs[3] & 0x000F);
+    acc += yl.s7 * (qs[3] & 0x0F00);
+    acc += yl.se * (qs[3] & 0x00F0);
+    acc += yl.sf * (qs[3] & 0xF000);
+
+    return d * (sumy * -8.f + acc);
+}
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 8 // each SIMD group works on 8 rows (in weights matrix)
+#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
+#elif defined (ADRENO_GPU)
+#define N_DST 8
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+//
+// This variant performs 1d blocking with 8x output.
+// Each simdgroup outputs 8 values along the `n0` dim (rows of the output matrix).
+//
+inline void mul_mat_q_n_f32_1d_8x_flat(
+        global uchar * src0_q,
+        global half  * src0_d,
+        global float * src1,
+        global float * dst,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    const int nb = ne00/QK4_0;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
+    // a SIMD group in the grid. Each SIMD group produces N_DST values in the
+    // result, hence uses nb blocks, i.e., the offset becomes first_row*nb.
+    // Currently with llama2 7B, im is always 0.
+    // TODO: how to handle im/gqa*(nb*ne0)?
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    // The number of scales is the same as the number of blocks.
+    ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
+    ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
+
+    global uchar * x = (global uchar *) src0_q + offset0_q;
+    global half  * d = (global half  *) src0_d + offset0_d;
+    global float * y = (global float *) src1   + r1*ne10 + im*ne00*ne1;
+
+    float16 yl;
+    float8 sumf = (float8)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f);
+
+    int ix = get_sub_group_local_id()/2;
+    int il = 8*(get_sub_group_local_id()%2);
+
+    global float * yb = y + ix*QK4_0 + il;
+
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
+        float sumy = 0.f;
+
+        sumy += yb[0];
+        sumy += yb[1];
+        sumy += yb[2];
+        sumy += yb[3];
+        sumy += yb[4];
+        sumy += yb[5];
+        sumy += yb[6];
+        sumy += yb[7];
+
+        sumy += yb[16];
+        sumy += yb[17];
+        sumy += yb[18];
+        sumy += yb[19];
+        sumy += yb[20];
+        sumy += yb[21];
+        sumy += yb[22];
+        sumy += yb[23];
+
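+        // Pre-scale the cached activations to match the unshifted nibble masks in
+        // mm_block_q_4_0_dot_y_flat: 0x000F needs no scale, 0x0F00 needs 1/256,
+        // 0x00F0 needs 1/16 and 0xF000 needs 1/4096.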
+        yl.s0 = yb[0];
+        yl.s1 = yb[1]/256.f;
+
+        yl.s2 = yb[2];
+        yl.s3 = yb[3]/256.f;
+
+        yl.s4 = yb[4];
+        yl.s5 = yb[5]/256.f;
+
+        yl.s6 = yb[6];
+        yl.s7 = yb[7]/256.f;
+
+        yl.s8 = yb[16]/16.f;
+        yl.s9 = yb[17]/4096.f;
+
+        yl.sa = yb[18]/16.f;
+        yl.sb = yb[19]/4096.f;
+
+        yl.sc = yb[20]/16.f;
+        yl.sd = yb[21]/4096.f;
+
+        yl.se = yb[22]/16.f;
+        yl.sf = yb[23]/4096.f;
+
+        sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il);
+        sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il);
+        sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il);
+        sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il);
+
+        sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il);
+        sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il);
+        sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il);
+        sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il);
+
+        yb += QK4_0 * (N_SIMDWIDTH/2);
+    }
+
+    float8 tot = (float8)(
+        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
+        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),
+        sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),
+        sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)
+    );
+
+    if (get_sub_group_local_id() == 0) {
+        if (first_row + 0 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
+        }
+        if (first_row + 1 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
+        }
+        if (first_row + 2 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
+        }
+        if (first_row + 3 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
+        }
+
+        if (first_row + 4 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
+        }
+        if (first_row + 5 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
+        }
+        if (first_row + 6 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
+        }
+        if (first_row + 7 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
+        }
+    }
+}
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_q4_0_f32_1d_8x_flat(
+        global uchar * src0_q,
+        global half  * src0_d,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    mul_mat_q_n_f32_1d_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
+}
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 16 // each SIMD group works on 16 rows (in weights matrix)
+#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // assuming SIMD group size is 16
+#elif defined (ADRENO_GPU)
+#define N_DST 16
+#define N_SIMDGROUP 1
+#define N_SIMDWIDTH 64
+#endif
+//
+// This variant performs 1d blocking with 16x output.
+// Each simdgroup outputs 16 values along the `n0` dim (rows of the output matrix).
+//
+inline void mul_mat_q_n_f32_1d_16x_flat(
+        global uchar * src0_q,
+        global half  * src0_d,
+        global float * src1,
+        global float * dst,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    const int nb = ne00/QK4_0;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of
+    // a SIMD group in the grid. Each SIMD group produces N_DST values in the
+    // result, hence uses nb blocks, i.e., the offset becomes first_row*nb.
+    // Currently with llama2 7B, im is always 0.
+    // TODO: how to handle im/gqa*(nb*ne0)?
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    // The number of scales is the same as the number of blocks.
+    ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
+    ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
+
+    global uchar * x = (global uchar *) src0_q + offset0_q;
+    global half  * d = (global half  *) src0_d + offset0_d;
+    global float * y = (global float *) src1   + r1*ne10 + im*ne00*ne1;
+
+    float16 yl;
+    float16 sumf = (float16)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
+                             0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f);
+
+    int ix = get_sub_group_local_id()/2;
+    int il = 8*(get_sub_group_local_id()%2);
+
+    global float * yb = y + ix*QK4_0 + il;
+
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
+        float sumy = 0.f;
+
+        sumy += yb[0];
+        sumy += yb[1];
+        sumy += yb[2];
+        sumy += yb[3];
+        sumy += yb[4];
+        sumy += yb[5];
+        sumy += yb[6];
+        sumy += yb[7];
+
+        sumy += yb[16];
+        sumy += yb[17];
+        sumy += yb[18];
+        sumy += yb[19];
+        sumy += yb[20];
+        sumy += yb[21];
+        sumy += yb[22];
+        sumy += yb[23];
+
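+        // same activation pre-scaling as in the 8x variant above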
+        yl.s0 = yb[0];
+        yl.s1 = yb[1]/256.f;
+
+        yl.s2 = yb[2];
+        yl.s3 = yb[3]/256.f;
+
+        yl.s4 = yb[4];
+        yl.s5 = yb[5]/256.f;
+
+        yl.s6 = yb[6];
+        yl.s7 = yb[7]/256.f;
+
+        yl.s8 = yb[16]/16.f;
+        yl.s9 = yb[17]/4096.f;
+
+        yl.sa = yb[18]/16.f;
+        yl.sb = yb[19]/4096.f;
+
+        yl.sc = yb[20]/16.f;
+        yl.sd = yb[21]/4096.f;
+
+        yl.se = yb[22]/16.f;
+        yl.sf = yb[23]/4096.f;
+
+        sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  0*nb*QK4_0/2, d + ib +  0*nb, sumy, yl, il);
+        sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  1*nb*QK4_0/2, d + ib +  1*nb, sumy, yl, il);
+        sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  2*nb*QK4_0/2, d + ib +  2*nb, sumy, yl, il);
+        sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  3*nb*QK4_0/2, d + ib +  3*nb, sumy, yl, il);
+
+        sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  4*nb*QK4_0/2, d + ib +  4*nb, sumy, yl, il);
+        sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  5*nb*QK4_0/2, d + ib +  5*nb, sumy, yl, il);
+        sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  6*nb*QK4_0/2, d + ib +  6*nb, sumy, yl, il);
+        sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  7*nb*QK4_0/2, d + ib +  7*nb, sumy, yl, il);
+
+        sumf.s8 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  8*nb*QK4_0/2, d + ib +  8*nb, sumy, yl, il);
+        sumf.s9 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 +  9*nb*QK4_0/2, d + ib +  9*nb, sumy, yl, il);
+        sumf.sa += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 10*nb*QK4_0/2, d + ib + 10*nb, sumy, yl, il);
+        sumf.sb += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 11*nb*QK4_0/2, d + ib + 11*nb, sumy, yl, il);
+
+        sumf.sc += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 12*nb*QK4_0/2, d + ib + 12*nb, sumy, yl, il);
+        sumf.sd += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 13*nb*QK4_0/2, d + ib + 13*nb, sumy, yl, il);
+        sumf.se += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 14*nb*QK4_0/2, d + ib + 14*nb, sumy, yl, il);
+        sumf.sf += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 15*nb*QK4_0/2, d + ib + 15*nb, sumy, yl, il);
+
+        yb += QK4_0 * (N_SIMDWIDTH/2);
+    }
+
+    float16 tot = (float16)(
+        sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
+        sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3),
+        sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5),
+        sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7),
+
+        sub_group_reduce_add(sumf.s8), sub_group_reduce_add(sumf.s9),
+        sub_group_reduce_add(sumf.sa), sub_group_reduce_add(sumf.sb),
+        sub_group_reduce_add(sumf.sc), sub_group_reduce_add(sumf.sd),
+        sub_group_reduce_add(sumf.se), sub_group_reduce_add(sumf.sf)
+    );
+
+    if (get_sub_group_local_id() == 0) {
+        if (first_row + 0 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
+        }
+        if (first_row + 1 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
+        }
+        if (first_row + 2 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
+        }
+        if (first_row + 3 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
+        }
+
+        if (first_row + 4 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
+        }
+        if (first_row + 5 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
+        }
+        if (first_row + 6 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
+        }
+        if (first_row + 7 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
+        }
+
+        if (first_row + 8 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 8] = tot.s8;
+        }
+        if (first_row + 9 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 9] = tot.s9;
+        }
+        if (first_row + 10 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 10] = tot.sa;
+        }
+        if (first_row + 11 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 11] = tot.sb;
+        }
+
+        if (first_row + 12 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 12] = tot.sc;
+        }
+        if (first_row + 13 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 13] = tot.sd;
+        }
+        if (first_row + 14 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 14] = tot.se;
+        }
+        if (first_row + 15 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 15] = tot.sf;
+        }
+    }
+}
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_q4_0_f32_1d_16x_flat(
+        global uchar * src0_q,
+        global half  * src0_d,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    mul_mat_q_n_f32_1d_16x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
+}
+
+//------------------------------------------------------------------------------
+// kernel_mul_mat_q4_0_f32_flat_v0
+//------------------------------------------------------------------------------
+inline float block_q_4_0_dot_y_flat_v2(
+    half   x,
+    half   d,
+    float  sumy,
+    float4 yl
+) {
+    uchar2 q = as_uchar2(x);
+    float acc = 0.0f;
+
+    acc += (q.s0 & 0x0F) * yl.s0;
+    acc += (q.s1 & 0x0F) * yl.s1;
+
+    acc += (q.s0 & 0xF0) * yl.s2;
+    acc += (q.s1 & 0xF0) * yl.s3;
+
+    return d * (sumy * -8.f + acc);
+}
+
+inline float block_q_4_0_dot_y_flat_v4(
+    float  x,
+    half   d,
+    float  sumy,
+    float8 yl
+) {
+    uchar4 q = as_uchar4(x);
+    float acc = 0.0f;
+
+    acc += (q.s0 & 0x0F) * yl.s0;
+    acc += (q.s1 & 0x0F) * yl.s1;
+    acc += (q.s2 & 0x0F) * yl.s2;
+    acc += (q.s3 & 0x0F) * yl.s3;
+
+    acc += (q.s0 & 0xF0) * yl.s4;
+    acc += (q.s1 & 0xF0) * yl.s5;
+    acc += (q.s2 & 0xF0) * yl.s6;
+    acc += (q.s3 & 0xF0) * yl.s7;
+
+    return d * (sumy * -8.f + acc);
+}
+
+inline float block_q_4_0_dot_y_flat_v8(
+    float2  x,
+    half    d,
+    float   sumy,
+    float16 yl
+) {
+    uchar8 q = as_uchar8(x);
+    float acc = 0.0f;
+
+    acc += (q.s0 & 0x0F) * yl.s0;
+    acc += (q.s1 & 0x0F) * yl.s1;
+    acc += (q.s2 & 0x0F) * yl.s2;
+    acc += (q.s3 & 0x0F) * yl.s3;
+    acc += (q.s4 & 0x0F) * yl.s4;
+    acc += (q.s5 & 0x0F) * yl.s5;
+    acc += (q.s6 & 0x0F) * yl.s6;
+    acc += (q.s7 & 0x0F) * yl.s7;
+
+    acc += (q.s0 & 0xF0) * yl.s8;
+    acc += (q.s1 & 0xF0) * yl.s9;
+    acc += (q.s2 & 0xF0) * yl.sa;
+    acc += (q.s3 & 0xF0) * yl.sb;
+    acc += (q.s4 & 0xF0) * yl.sc;
+    acc += (q.s5 & 0xF0) * yl.sd;
+    acc += (q.s6 & 0xF0) * yl.se;
+    acc += (q.s7 & 0xF0) * yl.sf;
+
+    return d * (sumy * -8.f + acc);
+}
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define THREADS_PER_BLK 4   // Number of threads per block, or each thread processes 1/THREADS_PER_BLK of a block
+#define N_DST           4
+#define N_SIMDGROUP     1
+#define N_SIMDWIDTH     16
+#elif defined (ADRENO_GPU)
+#define THREADS_PER_BLK 4
+#define N_DST           4
+#define N_SIMDGROUP     1
+#define N_SIMDWIDTH     64
+#endif
+
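+// Each thread covers 1/THREADS_PER_BLK of a block: ACT_TY caches QK4_0/THREADS_PER_BLK
+// activations and Q_BLK_LD_TY loads QK4_0/2/THREADS_PER_BLK quant bytes per row.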
+#if THREADS_PER_BLK == 2                // Each thread processes 1/2 block
+#   define ACT_TY                       float16
+#   define Q_BLK_LD_TY                  float2
+#   define block_q_4_0_dot_y_flat       block_q_4_0_dot_y_flat_v8
+#elif THREADS_PER_BLK == 4              // Each thread processes 1/4 block
+#   define ACT_TY                       float8
+#   define Q_BLK_LD_TY                  float
+#   define block_q_4_0_dot_y_flat       block_q_4_0_dot_y_flat_v4
+#elif THREADS_PER_BLK == 8              // Each thread processes 1/8 block
+#   define ACT_TY                       float4
+#   define Q_BLK_LD_TY                  half
+#   define block_q_4_0_dot_y_flat       block_q_4_0_dot_y_flat_v2
+#endif
+
+#define BTYES_PER_THREAD_IN_BLK         (QK4_0/2/THREADS_PER_BLK)
+
+#if N_DST == 2
+#   define  SUM_TY                      float2
+#elif N_DST == 4
+#   define  SUM_TY                      float4
+#elif N_DST == 8
+#   define  SUM_TY                      float8
+#elif N_DST == 16
+#   define  SUM_TY                      float16
+#endif
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_q4_0_f32_flat_v0(
+        global uchar * src0_q,
+        global half  * src0_d,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    const int nb = ne00/QK4_0;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    // The number of scales is the same as the number of blocks.
+    ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
+    ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2;
+
+    global uchar * x = (global uchar *) src0_q + offset0_q;
+    global half  * d = (global half  *) src0_d + offset0_d;
+    global float * y = (global float *) src1   + r1*ne10 + im*ne00*ne1;
+
+    int ix = get_sub_group_local_id()/THREADS_PER_BLK;
+    int il = get_sub_group_local_id()%THREADS_PER_BLK;
+
+    global float * yb = y + ix*QK4_0 + BTYES_PER_THREAD_IN_BLK*il;
+
+    // Registers for caching activation
+    ACT_TY yl = 0.f;
+
+    // Registers for caching quants
+    Q_BLK_LD_TY q_blk_0 = 0, q_blk_1 = 0;
+#if N_DST == 4 || N_DST == 8 || N_DST == 16
+    Q_BLK_LD_TY q_blk_2 = 0, q_blk_3 = 0;
+#endif
+#if N_DST == 8 || N_DST == 16
+    Q_BLK_LD_TY q_blk_4 = 0, q_blk_5 = 0, q_blk_6 = 0, q_blk_7 = 0;
+#endif
+
+    // Partial sum
+    SUM_TY sumf = 0.f;
+
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/THREADS_PER_BLK) {
+        float sumy = 0.f;
+
+        q_blk_0 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 0*nb*QK4_0/2);
+        q_blk_1 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 1*nb*QK4_0/2);
+#if N_DST == 4 || N_DST == 8 || N_DST == 16
+        q_blk_2 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 2*nb*QK4_0/2);
+        q_blk_3 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 3*nb*QK4_0/2);
+#endif
+#if N_DST == 8 || N_DST == 16
+        q_blk_4 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 4*nb*QK4_0/2));
+        q_blk_5 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 5*nb*QK4_0/2));
+        q_blk_6 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 6*nb*QK4_0/2));
+        q_blk_7 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 7*nb*QK4_0/2));
+#endif
+
+        // Load activation
+#if THREADS_PER_BLK == 2    // Each thread processes 1/2 block
+        yl.s01234567 = *(global float8 *)(yb);
+        yl.s89abcdef = *(global float8 *)(yb + 16);
+
+        sumy += yl.s0;
+        sumy += yl.s1;
+        sumy += yl.s2;
+        sumy += yl.s3;
+        sumy += yl.s4;
+        sumy += yl.s5;
+        sumy += yl.s6;
+        sumy += yl.s7;
+        sumy += yl.s8; yl.s8 /= 16.f;
+        sumy += yl.s9; yl.s9 /= 16.f;
+        sumy += yl.sa; yl.sa /= 16.f;
+        sumy += yl.sb; yl.sb /= 16.f;
+        sumy += yl.sc; yl.sc /= 16.f;
+        sumy += yl.sd; yl.sd /= 16.f;
+        sumy += yl.se; yl.se /= 16.f;
+        sumy += yl.sf; yl.sf /= 16.f;
+#elif THREADS_PER_BLK == 4  // Each thread processes 1/4 block
+        yl.s0123 = *(global float4 *)(yb);
+        yl.s4567 = *(global float4 *)(yb + 16);
+
+        sumy += yl.s0;
+        sumy += yl.s1;
+        sumy += yl.s2;
+        sumy += yl.s3;
+        sumy += yl.s4; yl.s4 /= 16.f;
+        sumy += yl.s5; yl.s5 /= 16.f;
+        sumy += yl.s6; yl.s6 /= 16.f;
+        sumy += yl.s7; yl.s7 /= 16.f;
+#elif THREADS_PER_BLK == 8  // Each thread processes 1/8 block
+        yl.s01 = *(global float2 *)(yb);
+        yl.s23 = *(global float2 *)(yb + 16);
+
+        sumy += yl.s0;
+        sumy += yl.s1;
+        sumy += yl.s2; yl.s2 /= 16.f;
+        sumy += yl.s3; yl.s3 /= 16.f;
+#endif
+
+        sumf.s0 += block_q_4_0_dot_y_flat(q_blk_0, *(d + ib + 0*nb), sumy, yl);
+        sumf.s1 += block_q_4_0_dot_y_flat(q_blk_1, *(d + ib + 1*nb), sumy, yl);
+#if N_DST == 4 || N_DST == 8 || N_DST == 16
+        sumf.s2 += block_q_4_0_dot_y_flat(q_blk_2, *(d + ib + 2*nb), sumy, yl);
+        sumf.s3 += block_q_4_0_dot_y_flat(q_blk_3, *(d + ib + 3*nb), sumy, yl);
+#endif
+#if N_DST == 8 || N_DST == 16
+        sumf.s4 += block_q_4_0_dot_y_flat(q_blk_4, *(d + ib + 4*nb), sumy, yl);
+        sumf.s5 += block_q_4_0_dot_y_flat(q_blk_5, *(d + ib + 5*nb), sumy, yl);
+        sumf.s6 += block_q_4_0_dot_y_flat(q_blk_6, *(d + ib + 6*nb), sumy, yl);
+        sumf.s7 += block_q_4_0_dot_y_flat(q_blk_7, *(d + ib + 7*nb), sumy, yl);
+#endif
+
+        yb += QK4_0 * (N_SIMDWIDTH/THREADS_PER_BLK);
+    }
+
+    SUM_TY tot = (SUM_TY)(
+          sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1)
+#if N_DST == 4 || N_DST == 8 || N_DST == 16
+        , sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
+#endif
+#if N_DST == 8 || N_DST == 16
+        , sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5)
+        , sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)
+#endif
+    );
+
+    if (get_sub_group_local_id() == 0) {
+        if (first_row + 0 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
+        }
+        if (first_row + 1 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
+        }
+#if N_DST == 4 || N_DST == 8 || N_DST == 16
+        if (first_row + 2 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
+        }
+        if (first_row + 3 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
+        }
+#endif
+#if N_DST == 8 || N_DST == 16
+        if (first_row + 4 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
+        }
+        if (first_row + 5 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
+        }
+        if (first_row + 6 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
+        }
+        if (first_row + 7 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
+        }
+#endif
+    }
+}
+
+//------------------------------------------------------------------------------
+// Using image1d_buffer_t
+
+#if defined(cl_qcom_subgroup_shuffle)
+#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable
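+// Wave-wide add via a shuffle-down tree reduction (Qualcomm extension); the offsets
+// 32..1 cover a subgroup of up to 64 fibers.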
+float qcom_sub_group_reduce_add(float sum) {
+    sum += qcom_sub_group_shuffle_down(sum, 32, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f);
+    sum += qcom_sub_group_shuffle_down(sum, 16, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f);
+    sum += qcom_sub_group_shuffle_down(sum,  8, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f);
+    sum += qcom_sub_group_shuffle_down(sum,  4, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f);
+    sum += qcom_sub_group_shuffle_down(sum,  2, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f);
+    sum += qcom_sub_group_shuffle_down(sum,  1, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f);
+    return sum;
+}
+#define sub_group_reduce_add qcom_sub_group_reduce_add
+#else
+#define sub_group_reduce_add sub_group_reduce_add
+#endif
+
+#undef THREADS_PER_BLK
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define THREADS_PER_BLK 4   // Number of threads per block, or each thread processes 1/THREADS_PER_BLK of a block
+#define N_DST           4
+#define N_SIMDGROUP     1
+#define N_SIMDWIDTH     16
+#elif defined (ADRENO_GPU)
+#define THREADS_PER_BLK 4
+#define N_DST           4
+#define N_SIMDGROUP     1
+#define N_SIMDWIDTH     64
+#endif
+
+#if THREADS_PER_BLK == 2                // Each thread processes 1/2 block
+#   define ACT_TY                       float16
+#   define Q_BLK_LD_TY                  float2
+#   define EXTRACT_BLK_DATA(tmp, part)  *((float2*)&tmp + part)
+#   define block_q_4_0_dot_y_flat       block_q_4_0_dot_y_flat_v8
+#elif THREADS_PER_BLK == 4              // Each thread processes 1/4 block
+#   define ACT_TY                       float8
+#   define Q_BLK_LD_TY                  float
+#   define EXTRACT_BLK_DATA(tmp, part)  *((float*)&tmp + part)
+#   define block_q_4_0_dot_y_flat       block_q_4_0_dot_y_flat_v4
+#elif THREADS_PER_BLK == 8              // Each thread processes 1/8 block
+#   define ACT_TY                       float4
+#   define Q_BLK_LD_TY                  half
+#   define EXTRACT_BLK_DATA(tmp, part)  *((half*)&tmp + part)
+#   define block_q_4_0_dot_y_flat       block_q_4_0_dot_y_flat_v2
+#endif
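+// EXTRACT_BLK_DATA pulls this thread's Q_BLK_LD_TY-sized slice out of the float4
+// texel read from src0_q (one texel holds the 16 quant bytes of one block).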
+
+#define BTYES_PER_THREAD_IN_BLK         (QK4_0/2/THREADS_PER_BLK)
+
+#if N_DST == 2
+#   define  SUM_TY                      float2
+#elif N_DST == 4
+#   define  SUM_TY                      float4
+#elif N_DST == 8
+#   define  SUM_TY                      float8
+#elif N_DST == 16
+#   define  SUM_TY                      float16
+#endif
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mat_q4_0_f32_flat_img_v0(
+        read_only image1d_buffer_t src0_q,
+        read_only image1d_buffer_t src0_d,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    const int nb = ne00/QK4_0;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    // The number of scales is the same as the number of blocks.
+    ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
+    ulong offset0_q = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    global float * y = (global float *) src1   + r1*ne10 + im*ne00*ne1;
+
+    int ix = get_sub_group_local_id()/THREADS_PER_BLK;
+    int il = get_sub_group_local_id()%THREADS_PER_BLK;
+
+    global float * yb = y + ix*QK4_0 + BTYES_PER_THREAD_IN_BLK*il;
+
+    // Registers for caching activation
+    ACT_TY yl = 0.f;
+
+    // Registers for caching quants
+    Q_BLK_LD_TY q_blk_0 = 0, q_blk_1 = 0;
+#if N_DST == 4 || N_DST == 8 || N_DST == 16
+    Q_BLK_LD_TY q_blk_2 = 0, q_blk_3 = 0;
+#endif
+#if N_DST == 8 || N_DST == 16
+    Q_BLK_LD_TY q_blk_4 = 0, q_blk_5 = 0, q_blk_6 = 0, q_blk_7 = 0;
+#endif
+
+    // Partial sum
+    SUM_TY sumf = 0.f;
+
+    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/THREADS_PER_BLK) {
+        float sumy = 0.f;
+
+        float4 tmp;
+        tmp = read_imagef(src0_q, offset0_q + ib + 0*nb);
+        q_blk_0 = EXTRACT_BLK_DATA(tmp, il);
+        tmp = read_imagef(src0_q, offset0_q + ib + 1*nb);
+        q_blk_1 = EXTRACT_BLK_DATA(tmp, il);
+#if N_DST == 4 || N_DST == 8 || N_DST == 16
+        tmp = read_imagef(src0_q, offset0_q + ib + 2*nb);
+        q_blk_2 = EXTRACT_BLK_DATA(tmp, il);
+        tmp = read_imagef(src0_q, offset0_q + ib + 3*nb);
+        q_blk_3 = EXTRACT_BLK_DATA(tmp, il);
+#endif
+#if N_DST == 8 || N_DST == 16
+        tmp = read_imagef(src0_q, offset0_q + ib + 4*nb);
+        q_blk_4 = EXTRACT_BLK_DATA(tmp, il);
+        tmp = read_imagef(src0_q, offset0_q + ib + 5*nb);
+        q_blk_5 = EXTRACT_BLK_DATA(tmp, il);
+        tmp = read_imagef(src0_q, offset0_q + ib + 6*nb);
+        q_blk_6 = EXTRACT_BLK_DATA(tmp, il);
+        tmp = read_imagef(src0_q, offset0_q + ib + 7*nb);
+        q_blk_7 = EXTRACT_BLK_DATA(tmp, il);
+#endif
+
+        // Load activation
+#if THREADS_PER_BLK == 2    // Each thread processes 1/2 block
+        yl.s01234567 = *(global float8 *)(yb);
+        yl.s89abcdef = *(global float8 *)(yb + 16);
+
+        sumy += yl.s0;
+        sumy += yl.s1;
+        sumy += yl.s2;
+        sumy += yl.s3;
+        sumy += yl.s4;
+        sumy += yl.s5;
+        sumy += yl.s6;
+        sumy += yl.s7;
+        sumy += yl.s8; yl.s8 /= 16.f;
+        sumy += yl.s9; yl.s9 /= 16.f;
+        sumy += yl.sa; yl.sa /= 16.f;
+        sumy += yl.sb; yl.sb /= 16.f;
+        sumy += yl.sc; yl.sc /= 16.f;
+        sumy += yl.sd; yl.sd /= 16.f;
+        sumy += yl.se; yl.se /= 16.f;
+        sumy += yl.sf; yl.sf /= 16.f;
+#elif THREADS_PER_BLK == 4  // Each thread processes 1/4 block
+        yl.s0123 = *(global float4 *)(yb);
+        yl.s4567 = *(global float4 *)(yb + 16);
+
+        sumy += yl.s0;
+        sumy += yl.s1;
+        sumy += yl.s2;
+        sumy += yl.s3;
+        sumy += yl.s4; yl.s4 /= 16.f;
+        sumy += yl.s5; yl.s5 /= 16.f;
+        sumy += yl.s6; yl.s6 /= 16.f;
+        sumy += yl.s7; yl.s7 /= 16.f;
+#elif THREADS_PER_BLK == 8  // Each thread processes 1/8 block
+        yl.s01 = *(global float2 *)(yb);
+        yl.s23 = *(global float2 *)(yb + 16);
+
+        sumy += yl.s0;
+        sumy += yl.s1;
+        sumy += yl.s2; yl.s2 /= 16.f;
+        sumy += yl.s3; yl.s3 /= 16.f;
+#endif
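+        // Note: the second half of the activations is pre-divided by 16 above; this assumes the
+        // block_q_4_0_dot_y_flat variants consume the high nibbles of the packed quants without
+        // shifting them down, folding the factor of 16 into the activations instead.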
+
+        sumf.s0 += block_q_4_0_dot_y_flat(q_blk_0, read_imageh(src0_d, offset0_d + ib + 0*nb).s0, sumy, yl);
+        sumf.s1 += block_q_4_0_dot_y_flat(q_blk_1, read_imageh(src0_d, offset0_d + ib + 1*nb).s0, sumy, yl);
+#if N_DST == 4 || N_DST == 8 || N_DST == 16
+        sumf.s2 += block_q_4_0_dot_y_flat(q_blk_2, read_imageh(src0_d, offset0_d + ib + 2*nb).s0, sumy, yl);
+        sumf.s3 += block_q_4_0_dot_y_flat(q_blk_3, read_imageh(src0_d, offset0_d + ib + 3*nb).s0, sumy, yl);
+#endif
+#if N_DST == 8 || N_DST == 16
+        sumf.s4 += block_q_4_0_dot_y_flat(q_blk_4, read_imageh(src0_d, offset0_d + ib + 4*nb).s0, sumy, yl);
+        sumf.s5 += block_q_4_0_dot_y_flat(q_blk_5, read_imageh(src0_d, offset0_d + ib + 5*nb).s0, sumy, yl);
+        sumf.s6 += block_q_4_0_dot_y_flat(q_blk_6, read_imageh(src0_d, offset0_d + ib + 6*nb).s0, sumy, yl);
+        sumf.s7 += block_q_4_0_dot_y_flat(q_blk_7, read_imageh(src0_d, offset0_d + ib + 7*nb).s0, sumy, yl);
+#endif
+
+        yb += QK4_0 * (N_SIMDWIDTH/THREADS_PER_BLK);
+    }
+
+    SUM_TY tot = (SUM_TY)(
+          sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1)
+#if N_DST == 4 || N_DST == 8 || N_DST == 16
+        , sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
+#endif
+#if N_DST == 8 || N_DST == 16
+        , sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5)
+        , sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7)
+#endif
+    );
+
+    if (get_sub_group_local_id() == 0) {
+        if (first_row + 0 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
+        }
+        if (first_row + 1 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
+        }
+#if N_DST == 4 || N_DST == 8 || N_DST == 16
+        if (first_row + 2 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
+        }
+        if (first_row + 3 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
+        }
+#endif
+#if N_DST == 8 || N_DST == 16
+        if (first_row + 4 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4;
+        }
+        if (first_row + 5 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5;
+        }
+        if (first_row + 6 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6;
+        }
+        if (first_row + 7 < ne01) {
+            dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7;
+        }
+#endif
+    }
+}
+
+//------------------------------------------------------------------------------
+// kernel_mul_mv_q6_K_f32
+//------------------------------------------------------------------------------
+
+#undef N_DST
+#undef N_SIMDGROUP
+#undef N_SIMDWIDTH
+
+#ifdef INTEL_GPU
+#define N_DST 1 // number of rows each SIMD group works on
+#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
+#define N_SIMDWIDTH 16 // SIMD group size
+#elif defined (ADRENO_GPU)
+#define N_DST 1
+#define N_SIMDGROUP 2
+#define N_SIMDWIDTH 64
+#endif
+
+#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes
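+// e.g. BLOCK_STRIDE is 16/16 = 1 on the Intel path and 64/16 = 4 on the Adreno path.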
+
+#ifdef INTEL_GPU
+REQD_SUBGROUP_SIZE_16
+#elif defined (ADRENO_GPU)
+REQD_SUBGROUP_SIZE_64
+#endif
+kernel void kernel_mul_mv_q6_K_f32(
+        global void * src0,
+        ulong offset0,
+        global float * src1,
+        ulong offset1,
+        global float * dst,
+        ulong offsetd,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne10,
+        int ne12,
+        int ne0,
+        int ne1,
+        int r2,
+        int r3
+) {
+    src0 = (global void*)((global char*)src0 + offset0);
+    src1 = (global float*)((global char*)src1 + offset1);
+    dst = (global float*)((global char*)dst + offsetd);
+
+    uchar kmask1 = 0x03;
+    uchar kmask2 = 0x0C;
+    uchar kmask3 = 0x30;
+    uchar kmask4 = 0xC0;
+
+    int nb = ne00/QK_K;
+
+    int r0 = get_group_id(0);
+    int r1 = get_group_id(1);
+    int im = get_group_id(2);
+
+    int row = N_SIMDGROUP * r0 + get_sub_group_id();
+
+    int i12 = im%ne12;
+    int i13 = im/ne12;
+
+    ulong offset_src0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+    global block_q6_K * x = (global block_q6_K *) src0 + row*nb + offset_src0;
+    global float      * yy = (global float     *) src1 + r1*ne10 + im*ne00*ne1;
+
+    float sumf = 0;
+
+    // For Q6_K quantization, 16 values form a subblock, and 16 subblocks form a
+    // block. Values in a subblock share a scale that is quantized with 8 bits;
+    // the entire block shares a single floating point scale.
+    // For work distribution, each thread processes a subblock (16 weights), hence
+    // 16 threads process a (super) block -- a subgroup thus handles SIMDWIDTH/16
+    // (super) blocks -- this is the block stride.
+    // The 16 threads that process a (super) block are split into 2 portions of
+    // 8 threads each; each portion works on 8 subblocks.
+    // For a subgroup of 16 threads, the entire subgroup works on a single (super) block
+    // before moving to the next (super) block. Thread0 - thread7 work on the
+    // first 8 subblocks; thread8 - thread15 work on the last 8 subblocks.
+    // Thread0 - thread3 work on subblocks 0, 2, 4, 6; thread4 - thread7 work on
+    // subblocks 1, 3, 5, 7. Each thread does not work on an entire subblock, but
+    // works on a total of 16 weight values.
+    int tid  = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0
+    int ix   = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1
+    int ip   = tid/8;   // first or second half of (super) block (0 or 1)
+    int il   = tid%8;   // each half has 8 parts, one per scale
+    int n    = 4;       // 4 scales at a time (and 4 sums)
+    int l0   = n*il;    // offset into half-block, 0..28
+    int is   = 8*ip + l0/16; // 0, 1, 8, 9
+
+    int y_offset = 128*ip + l0;
+    int q_offset_l = 64*ip + l0;
+    int q_offset_h = 32*ip + l0;
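+    // Worked example (illustrative, Intel path with N_SIMDWIDTH = 16, BLOCK_STRIDE = 1):
+    // lane 13 gets tid = 13, ix = 0, ip = 1, il = 5, l0 = 20, is = 9, y_offset = 148,
+    // q_offset_l = 84 and q_offset_h = 52, i.e. it handles 4 values from each of the four
+    // 32-value quadrants of the second half of every (super) block it visits.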
+
+    for (int i = ix; i < nb; i += BLOCK_STRIDE) {
+
+        global uint8_t * q1 = x[i].ql + q_offset_l;
+        global uint8_t * q2 = q1 + QK_K/8;
+        global uint8_t * qh = x[i].qh + q_offset_h;
+        global int8_t  * sc = x[i].scales + is;
+
+        global float * y = yy + i * QK_K + y_offset;
+
+        float dall = x[i].d;
+
+        float4 sums = {0.f, 0.f, 0.f, 0.f};
+
+        sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & kmask1) << 4)) - 32.f);
+        sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & kmask2) << 2)) - 32.f);
+        sums.s2 += y[0+64] * ((float)((q1[0]  >> 4) | ((qh[0] & kmask3) << 0)) - 32.f);
+        sums.s3 += y[0+96] * ((float)((q2[0]  >> 4) | ((qh[0] & kmask4) >> 2)) - 32.f);
+
+        sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & kmask1) << 4)) - 32.f);
+        sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & kmask2) << 2)) - 32.f);
+        sums.s2 += y[1+64] * ((float)((q1[1]  >> 4) | ((qh[1] & kmask3) << 0)) - 32.f);
+        sums.s3 += y[1+96] * ((float)((q2[1]  >> 4) | ((qh[1] & kmask4) >> 2)) - 32.f);
+
+        sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & kmask1) << 4)) - 32.f);
+        sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & kmask2) << 2)) - 32.f);
+        sums.s2 += y[2+64] * ((float)((q1[2]  >> 4) | ((qh[2] & kmask3) << 0)) - 32.f);
+        sums.s3 += y[2+96] * ((float)((q2[2]  >> 4) | ((qh[2] & kmask4) >> 2)) - 32.f);
+
+        sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & kmask1) << 4)) - 32.f);
+        sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & kmask2) << 2)) - 32.f);
+        sums.s2 += y[3+64] * ((float)((q1[3]  >> 4) | ((qh[3] & kmask3) << 0)) - 32.f);
+        sums.s3 += y[3+96] * ((float)((q2[3]  >> 4) | ((qh[3] & kmask4) >> 2)) - 32.f);
+
+        sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]);
+    }
+
+    float tot = sub_group_reduce_add(sumf);
+    if (get_sub_group_local_id() == 0) {
+        dst[r1*ne0 + im*ne0*ne1 + row] = tot;
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl
new file mode 100644
index 00000000000..57768c80334
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl
@@ -0,0 +1,130 @@
+// src0_q, src0_d, src1 are transposed as a preprocessing step
+// 4-bit weights are transposed in groups of 4 (unsigned short int)
+// consider weights originally "next to each other", now "on top of each other"
+// each fiber computes an 8x4 tile of output elements
+// using unshuffled weights
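+// Reading of the packing as used by the dequantization below: each ushort holds 4 consecutive
+// 4-bit weights along K for one output row (nibble j is extracted with mask 0xF << 4*j), and
+// bits4.s0..s3 cover the same 4 K positions for 4 adjacent output rows, so a single vload4
+// provides a 4x4 weight tile.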
+
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
+
+__attribute__((qcom_reqd_sub_group_size("full")))
+kernel void kernel_mul_mat_Ab_Bi_8x4(
+        global const ushort * src0_q,       // quantized A
+        global const half  * src0_d,        // A scales
+        __read_only image1d_buffer_t src1,  // B (1d image)
+        global float * dst,                 // C
+        int m,                              // M
+        int n,                              // N with padding
+        int k,                              // K
+        int n_no_padding                    // N without padding
+) {
+
+    int m_4 = m >> 2;
+    int n_4 = n >> 2;
+
+    int gy = get_global_id(0);
+    int gx = get_global_id(1);
+    int gx_2 = gx << 2;
+
+    half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; // 8x4 output elements
+    half8 B; // registers for activations
+    half4 dequantized_weights; // registers for dequantized weights
+    __global const ushort* weight_ptr = src0_q + gx_2; // pointer for weights
+    __global const half* scale_ptr = src0_d + gx_2; // pointer for scales
+
+    for(int i=0; i<k; i+=4){
+        // load 8 activations (2 texels) for K position i
+        B.s0123 = read_imageh(src1, gy*2 + (i)*(n_4));
+        B.s4567 = read_imageh(src1, gy*2 + (i)*(n_4)+1);
+
+        // load 4 consecutive groups of 4 packed weights (one ushort = 4 nibbles along K)
+        ushort4 bits4 = vload4(0, weight_ptr + (i/4)*(m));
+        // load the 4 corresponding scales (one scale per 32 weights)
+        half4 scale = vload4(0, scale_ptr + (i/32)*(m));
+
+        // j=0
+        dequantized_weights.s0 = ((bits4.s0 & (0x000F)) - 8) * scale.s0; // dequantize a row of the 16 weights
+        dequantized_weights.s1 = ((bits4.s1 & (0x000F)) - 8) * scale.s1;
+        dequantized_weights.s2 = ((bits4.s2 & (0x000F)) - 8) * scale.s2;
+        dequantized_weights.s3 = ((bits4.s3 & (0x000F)) - 8) * scale.s3;
+        c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
+        c1 += B * dequantized_weights.s1;
+        c2 += B * dequantized_weights.s2;
+        c3 += B * dequantized_weights.s3;
+
+        // j=1
+        B.s0123 = read_imageh(src1, gy*2 + (i+1)*(n_4));
+        B.s4567 = read_imageh(src1, gy*2 + (i+1)*(n_4)+1);
+        dequantized_weights.s0 = (((bits4.s0 & (0x00F0)) >> 4) - 8) * scale.s0; // dequantize a row of the 16 weights
+        dequantized_weights.s1 = (((bits4.s1 & (0x00F0)) >> 4) - 8) * scale.s1;
+        dequantized_weights.s2 = (((bits4.s2 & (0x00F0)) >> 4) - 8) * scale.s2;
+        dequantized_weights.s3 = (((bits4.s3 & (0x00F0)) >> 4) - 8) * scale.s3;
+        c0 += B * dequantized_weights.s0; //vector-scalar multiplication to accumulate
+        c1 += B * dequantized_weights.s1;
+        c2 += B * dequantized_weights.s2;
+        c3 += B * dequantized_weights.s3;
+
+        // j=2
+        B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4));
+        B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1);
+        dequantized_weights.s0 = (((bits4.s0 & (0x0F00)) >> 8) - 8) * scale.s0; // dequantize a row of the 16 weights
+        dequantized_weights.s1 = (((bits4.s1 & (0x0F00)) >> 8) - 8) * scale.s1;
+        dequantized_weights.s2 = (((bits4.s2 & (0x0F00)) >> 8) - 8) * scale.s2;
+        dequantized_weights.s3 = (((bits4.s3 & (0x0F00)) >> 8) - 8) * scale.s3;
+        c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
+        c1 += B * dequantized_weights.s1;
+        c2 += B * dequantized_weights.s2;
+        c3 += B * dequantized_weights.s3;
+
+        // j=3
+        B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4));
+        B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1);
+        dequantized_weights.s0 = (((bits4.s0 & (0xF000)) >> 12) - 8) * scale.s0; // dequantize a row of the 16 weights
+        dequantized_weights.s1 = (((bits4.s1 & (0xF000)) >> 12) - 8) * scale.s1;
+        dequantized_weights.s2 = (((bits4.s2 & (0xF000)) >> 12) - 8) * scale.s2;
+        dequantized_weights.s3 = (((bits4.s3 & (0xF000)) >> 12) - 8) * scale.s3;
+        c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
+        c1 += B * dequantized_weights.s1;
+        c2 += B * dequantized_weights.s2;
+        c3 += B * dequantized_weights.s3;
+    }
+
+    int idx = (gy<<3)*m + (gx<<2); // vectorized store 16 elements
+
+    // conditional check that each store targets a valid location; required when N is not a multiple of 8
+    // the if statements allow registers to be reused for each store, which reduces the register
+    // footprint and increases the number of concurrent waves, providing a performance boost
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
+        idx += m;
+    }
+    if(idx+3 < m*n_no_padding){
+        vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
+    }
+}
diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl
new file mode 100644
index 00000000000..d59a0c05ddf
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl
@@ -0,0 +1,32 @@
+// 16-bit transpose, loading/storing an 8x8 tile of elements
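+// Each work-item loads an 8x8 tile of 16-bit values (eight float4 texels of 8 halfs each)
+// and stores the transposed tile; rows and cols appear to be expressed in texels rather
+// than in individual elements.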
+
+kernel void kernel_transpose_16(
+    __read_only image1d_buffer_t input,
+    __write_only image1d_buffer_t output,
+    const uint rows,
+    const uint cols
+) {
+
+    const int i = get_global_id(0);
+    const int j = get_global_id(1);
+    const int i_3 = i<<3;
+    const int j_3 = j<<3;
+
+    ushort8 temp0 = as_ushort8(read_imagef(input, (j_3+0)*cols+i));
+    ushort8 temp1 = as_ushort8(read_imagef(input, (j_3+1)*cols+i));
+    ushort8 temp2 = as_ushort8(read_imagef(input, (j_3+2)*cols+i));
+    ushort8 temp3 = as_ushort8(read_imagef(input, (j_3+3)*cols+i));
+    ushort8 temp4 = as_ushort8(read_imagef(input, (j_3+4)*cols+i));
+    ushort8 temp5 = as_ushort8(read_imagef(input, (j_3+5)*cols+i));
+    ushort8 temp6 = as_ushort8(read_imagef(input, (j_3+6)*cols+i));
+    ushort8 temp7 = as_ushort8(read_imagef(input, (j_3+7)*cols+i));
+
+    write_imagef(output, (i_3+0)*rows+j, as_float4((ushort8)(temp0.s0, temp1.s0, temp2.s0, temp3.s0, temp4.s0, temp5.s0, temp6.s0, temp7.s0)));
+    write_imagef(output, (i_3+1)*rows+j, as_float4((ushort8)(temp0.s1, temp1.s1, temp2.s1, temp3.s1, temp4.s1, temp5.s1, temp6.s1, temp7.s1)));
+    write_imagef(output, (i_3+2)*rows+j, as_float4((ushort8)(temp0.s2, temp1.s2, temp2.s2, temp3.s2, temp4.s2, temp5.s2, temp6.s2, temp7.s2)));
+    write_imagef(output, (i_3+3)*rows+j, as_float4((ushort8)(temp0.s3, temp1.s3, temp2.s3, temp3.s3, temp4.s3, temp5.s3, temp6.s3, temp7.s3)));
+    write_imagef(output, (i_3+4)*rows+j, as_float4((ushort8)(temp0.s4, temp1.s4, temp2.s4, temp3.s4, temp4.s4, temp5.s4, temp6.s4, temp7.s4)));
+    write_imagef(output, (i_3+5)*rows+j, as_float4((ushort8)(temp0.s5, temp1.s5, temp2.s5, temp3.s5, temp4.s5, temp5.s5, temp6.s5, temp7.s5)));
+    write_imagef(output, (i_3+6)*rows+j, as_float4((ushort8)(temp0.s6, temp1.s6, temp2.s6, temp3.s6, temp4.s6, temp5.s6, temp6.s6, temp7.s6)));
+    write_imagef(output, (i_3+7)*rows+j, as_float4((ushort8)(temp0.s7, temp1.s7, temp2.s7, temp3.s7, temp4.s7, temp5.s7, temp6.s7, temp7.s7)));
+}
diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl
new file mode 100644
index 00000000000..914ec0193e7
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl
@@ -0,0 +1,25 @@
+// 32-bit transpose, loading/storing a 4x4 tile of elements
+
+kernel void kernel_transpose_32(
+    __read_only image1d_buffer_t input,
+    __write_only image1d_buffer_t output,
+    const uint rows,
+    const uint cols
+) {
+
+    const int i = get_global_id(0);
+    const int j = get_global_id(1);
+    const int i_2 = i<<2;
+    const int j_2 = j<<2;
+
+    float4 temp0 = read_imagef(input, (j_2+0)*cols+i);
+    float4 temp1 = read_imagef(input, (j_2+1)*cols+i);
+    float4 temp2 = read_imagef(input, (j_2+2)*cols+i);
+    float4 temp3 = read_imagef(input, (j_2+3)*cols+i);
+
+    write_imagef(output, (i_2+0)*rows+j, (float4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0));
+    write_imagef(output, (i_2+1)*rows+j, (float4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
+    write_imagef(output, (i_2+2)*rows+j, (float4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
+    write_imagef(output, (i_2+3)*rows+j, (float4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
+
+}
diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl
new file mode 100644
index 00000000000..d3bd1fabb76
--- /dev/null
+++ b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl
@@ -0,0 +1,35 @@
+// 32-bit transpose, loading/storing a 4x4 tile of elements
+// Only used for activations
+// converts to FP16
+// also adds zero padding for prompt lengths that are not a multiple of 8
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+kernel void kernel_transpose_32_16(__read_only image1d_buffer_t input, __write_only image1d_buffer_t output, const uint rows, const uint cols, const uint padded_rows) {
+
+    const int i = get_global_id(0);
+    const int j = get_global_id(1);
+    const int i_2 = i<<2;
+    const int j_2 = j<<2;
+    half4 temp0 = {0,0,0,0}; // initialize outputs to 0
+    half4 temp1 = {0,0,0,0};
+    half4 temp2 = {0,0,0,0};
+    half4 temp3 = {0,0,0,0};
+
+    if((j_2+0)*cols+i*4+3 < rows*cols*16){ // only load from a valid location. Otherwise keep register data as 0
+        temp0 = read_imageh(input, (j_2+0)*cols+i);
+    }
+    if((j_2+1)*cols+i*4+3 < rows*cols*16){
+        temp1 = read_imageh(input, (j_2+1)*cols+i);
+    }
+    if((j_2+2)*cols+i*4+3 < rows*cols*16){
+        temp2 = read_imageh(input, (j_2+2)*cols+i);
+    }
+    if((j_2+3)*cols+i*4+3 < rows*cols*16){
+        temp3 = read_imageh(input, (j_2+3)*cols+i);
+    }
+
+    write_imageh(output, (i_2+0)*padded_rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); // no conditionals for output, includes zero padding
+    write_imageh(output, (i_2+1)*padded_rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
+    write_imageh(output, (i_2+2)*padded_rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
+    write_imageh(output, (i_2+3)*padded_rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
+}
diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
index 43108242639..63da2b86b1b 100644
--- a/ggml/src/ggml-rpc/ggml-rpc.cpp
+++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -27,15 +27,6 @@
 #endif
 #include 
 
-#define UNUSED GGML_UNUSED
-
-#define GGML_DEBUG 0
-#if (GGML_DEBUG >= 1)
-#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
-#else
-#define GGML_PRINT_DEBUG(...)
-#endif
-
 #ifdef _WIN32
 typedef SOCKET sockfd_t;
 using ssize_t = __int64;
@@ -93,9 +84,23 @@ enum rpc_cmd {
     RPC_CMD_COPY_TENSOR,
     RPC_CMD_GRAPH_COMPUTE,
     RPC_CMD_GET_DEVICE_MEMORY,
+    RPC_CMD_INIT_TENSOR,
+    RPC_CMD_GET_ALLOC_SIZE,
     RPC_CMD_COUNT,
 };
 
+struct rpc_msg_get_alloc_size_req {
+    rpc_tensor tensor;
+};
+
+struct rpc_msg_get_alloc_size_rsp {
+    uint64_t alloc_size;
+};
+
+struct rpc_msg_init_tensor_req {
+    rpc_tensor tensor;
+};
+
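+// Flow sketch for the two new commands (see get_alloc_size/init_tensor below): the client
+// serializes the tensor, sends RPC_CMD_GET_ALLOC_SIZE and receives the buffer-type-specific
+// allocation size; RPC_CMD_INIT_TENSOR is only sent for quantized tensors whose first
+// dimension is not a multiple of 512, i.e. when server-side padding is actually needed.
+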
 struct rpc_msg_alloc_buffer_req {
     uint64_t size;
 };
@@ -397,7 +402,7 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
         initialized = true;
     }
 #else
-    UNUSED(initialized);
+    GGML_UNUSED(initialized);
 #endif
     auto sock = socket_connect(host.c_str(), port);
     if (sock == nullptr) {
@@ -461,10 +466,18 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
 }
 
 static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
-    UNUSED(buffer);
-    if (ggml_is_quantized(tensor->type)) {
-        // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized
-        GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor");
+    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
+
+    // CUDA backend on the server pads everything to 512 due to CUDA limitations.
+    // Due to bandwidth constraints, we only call the server init tensor functions if necessary.
+    // In particular, only quantized tensors need padding
+    if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
+        rpc_msg_init_tensor_req request;
+
+        request.tensor = serialize_tensor(tensor);
+
+        bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
+        GGML_ASSERT(status);
     }
 }
 
@@ -577,8 +590,23 @@ static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
 }
 
 static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
-    UNUSED(buft);
-    return ggml_nbytes(tensor);
+    // See comments in init_tensor.
+    if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
+        ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
+        auto sock = get_socket(buft_ctx->endpoint);
+
+        rpc_msg_get_alloc_size_req request;
+
+        request.tensor = serialize_tensor(tensor);
+
+        rpc_msg_get_alloc_size_rsp response;
+        bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
+        GGML_ASSERT(status);
+
+        return response.alloc_size;
+    } else {
+        return ggml_nbytes(tensor);
+    }
 }
 
 static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -603,7 +631,7 @@ static void ggml_backend_rpc_free(ggml_backend_t backend) {
 }
 
 static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
-    UNUSED(backend);
+    GGML_UNUSED(backend);
     // this is no-op because we don't have any async operations
 }
 
@@ -757,6 +785,8 @@ class rpc_server {
     bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
     bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
     bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
+    bool init_tensor(const rpc_msg_init_tensor_req & request);
+    bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
 
 private:
     ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@@ -770,6 +800,36 @@ class rpc_server {
     std::unordered_set<ggml_backend_buffer_t> buffers;
 };
 
+bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
+    ggml_backend_buffer_type_t buft;
+    struct ggml_init_params params {
+        /*.mem_size   =*/ ggml_tensor_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+
+    struct ggml_context * ctx = ggml_init(params);
+    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
+
+    if (tensor == nullptr) {
+        GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
+        ggml_free(ctx);
+        return false;
+    }
+
+    if (tensor->buffer == nullptr) {
+        //No buffer allocated.
+        buft = ggml_backend_get_default_buffer_type(backend);
+    } else {
+        buft = tensor->buffer->buft;
+    }
+
+    response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor);
+
+    ggml_free(ctx);
+    return true;
+}
+
 void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
     ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
     ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
@@ -781,7 +841,7 @@ void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_
         GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
         buffers.insert(buffer);
     } else {
-        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
+        GGML_LOG_ERROR("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
     }
 }
 
@@ -803,7 +863,7 @@ bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rp
     GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
     ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
     if (buffers.find(buffer) == buffers.end()) {
-        GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
+        GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
         return false;
     }
     void * base = ggml_backend_buffer_get_base(buffer);
@@ -815,7 +875,7 @@ bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
     GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
     ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
     if (buffers.find(buffer) == buffers.end()) {
-        GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
+        GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
         return false;
     }
     ggml_backend_buffer_free(buffer);
@@ -827,7 +887,7 @@ bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
     GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
     ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
     if (buffers.find(buffer) == buffers.end()) {
-        GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
+        GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
         return false;
     }
     ggml_backend_buffer_clear(buffer, request.value);
@@ -883,7 +943,7 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
     struct ggml_context * ctx = ggml_init(params);
     ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
     if (tensor == nullptr) {
-        GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
+        GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
         ggml_free(ctx);
         return false;
     }
@@ -905,6 +965,40 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
     return true;
 }
 
+bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
+    struct ggml_init_params params {
+        /*.mem_size   =*/ ggml_tensor_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+    struct ggml_context * ctx = ggml_init(params);
+    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
+    if (tensor == nullptr) {
+        GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n");
+        ggml_free(ctx);
+        return false;
+    }
+
+    // Call the backend's buffer_init_tensor function
+    ggml_backend_buffer_t buffer = tensor->buffer;
+    if (buffer && buffer->iface.init_tensor) {
+        buffer->iface.init_tensor(buffer, tensor);
+    } else {
+        GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n");
+    }
+
+    if (tensor->extra != nullptr) {
+        // This pointer can either be passed around client/server, or probably better stored server-side and kept track of.
+        // Currently unimplemented.
+        GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n");
+        ggml_free(ctx);
+        return false;
+    }
+
+    ggml_free(ctx);
+    return true;
+}
+
 bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
     struct ggml_init_params params {
         /*.mem_size   =*/ ggml_tensor_overhead(),
@@ -914,7 +1008,7 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<
     struct ggml_context * ctx = ggml_init(params);
     ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
     if (tensor == nullptr) {
-        GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
+        GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
         ggml_free(ctx);
         return false;
     }
@@ -948,7 +1042,7 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
     ggml_tensor * src = deserialize_tensor(ctx, &request.src);
     ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
     if (src == nullptr || dst == nullptr) {
-        GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
+        GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__);
         ggml_free(ctx);
         return false;
     }
@@ -1058,6 +1152,18 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
                 }
                 break;
             }
+            case RPC_CMD_GET_ALLOC_SIZE: {
+                rpc_msg_get_alloc_size_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                rpc_msg_get_alloc_size_rsp response;
+                server.get_alloc_size(request, response);
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
+                break;
+            }
             case RPC_CMD_GET_ALIGNMENT: {
                 if (!recv_msg(sockfd, nullptr, 0)) {
                     return;
@@ -1133,6 +1239,19 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
                 }
                 break;
             }
+            case RPC_CMD_INIT_TENSOR: {
+                rpc_msg_init_tensor_req request;
+                if (!recv_msg(sockfd, &request,sizeof(request))) {
+                    return;
+                }
+                if (!server.init_tensor(request)) {
+                    return;
+                }
+                if (!send_msg(sockfd, nullptr, 0)) {
+                    return;
+                }
+                break;
+            }
             case RPC_CMD_GET_TENSOR: {
                 rpc_msg_get_tensor_req request;
                 if (!recv_msg(sockfd, &request, sizeof(request))) {
@@ -1257,14 +1376,14 @@ static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t *
 
     ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
 
-    UNUSED(dev);
+    GGML_UNUSED(dev);
 }
 
 static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
     // TODO: obtain value from the server
     return GGML_BACKEND_DEVICE_TYPE_GPU;
 
-    UNUSED(dev);
+    GGML_UNUSED(dev);
 }
 
 static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
@@ -1285,7 +1404,7 @@ static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const
 
     return ggml_backend_rpc_init(ctx->endpoint.c_str());
 
-    UNUSED(params);
+    GGML_UNUSED(params);
 }
 
 static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
@@ -1293,12 +1412,12 @@ static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_b
 
     return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
 
-    UNUSED(dev);
+    GGML_UNUSED(dev);
 }
 
 static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
-    UNUSED(dev);
-    UNUSED(op);
+    GGML_UNUSED(dev);
+    GGML_UNUSED(op);
     //TODO: call the remote backend and cache the results
     return true;
 }
@@ -1335,20 +1454,20 @@ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
 static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
     return "RPC";
 
-    UNUSED(reg);
+    GGML_UNUSED(reg);
 }
 
 static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
     return 0;
 
-    UNUSED(reg);
+    GGML_UNUSED(reg);
 }
 
 static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
     GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
 
-    UNUSED(reg);
-    UNUSED(index);
+    GGML_UNUSED(reg);
+    GGML_UNUSED(index);
 }
 
 static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
@@ -1357,7 +1476,7 @@ static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const ch
     }
     return NULL;
 
-    UNUSED(reg);
+    GGML_UNUSED(reg);
 }
 
 static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
diff --git a/ggml/src/ggml-sycl/common.cpp b/ggml/src/ggml-sycl/common.cpp
index 88314a5cd73..022e7b7637b 100644
--- a/ggml/src/ggml-sycl/common.cpp
+++ b/ggml/src/ggml-sycl/common.cpp
@@ -51,6 +51,10 @@ void ggml_sycl_host_free(void* ptr) try {
   std::exit(1);
 }
 
+bool gpu_has_xmx(sycl::device &dev) {
+    return dev.has(sycl::aspect::ext_intel_matrix);
+}
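+// Used below when printing the device table: the ext_intel_matrix aspect indicates
+// Intel XMX (matrix engine) support.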
+
 int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) {
   const int64_t max_range = std::numeric_limits<int>::max();
   int64_t sycl_down_blk_size = block_size;
diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp
index 62b4cea3ada..e9500f3a168 100644
--- a/ggml/src/ggml-sycl/common.hpp
+++ b/ggml/src/ggml-sycl/common.hpp
@@ -662,6 +662,7 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t
     }
 }
 
+bool gpu_has_xmx(sycl::device &dev);
 
 void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
                                  const ggml_tensor *src1, ggml_tensor *dst,
diff --git a/ggml/src/ggml-sycl/concat.cpp b/ggml/src/ggml-sycl/concat.cpp
index a240968ad2e..d41cfd3a6ec 100644
--- a/ggml/src/ggml-sycl/concat.cpp
+++ b/ggml/src/ggml-sycl/concat.cpp
@@ -158,8 +158,9 @@ static void concat_f32_sycl_non_cont(
       });
 }
 
-void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                const ggml_tensor *src1, ggml_tensor *dst) {
+void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+  const ggml_tensor *src0 = dst->src[0];
+  const ggml_tensor *src1 = dst->src[1];
   queue_ptr stream = ctx.stream();
 
   const int32_t dim = ((int32_t *)dst->op_params)[0];
diff --git a/ggml/src/ggml-sycl/concat.hpp b/ggml/src/ggml-sycl/concat.hpp
index 5a04feaab6b..e5cb7314c9f 100644
--- a/ggml/src/ggml-sycl/concat.hpp
+++ b/ggml/src/ggml-sycl/concat.hpp
@@ -15,7 +15,6 @@
 
 #include "common.hpp"
 
-void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                const ggml_tensor *src1, ggml_tensor *dst);
+void ggml_sycl_op_concat(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
 
 #endif // GGML_SYCL_CONCAT_HPP
diff --git a/ggml/src/ggml-sycl/conv.cpp b/ggml/src/ggml-sycl/conv.cpp
index bc4ab1ddbad..ddba601e10f 100644
--- a/ggml/src/ggml-sycl/conv.cpp
+++ b/ggml/src/ggml-sycl/conv.cpp
@@ -71,8 +71,9 @@ static void conv_transpose_1d_f32_f32_sycl(
         });
 }
 
-void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-    const ggml_tensor *src1, ggml_tensor *dst) {
+void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];
     const float * src0_d = (const float *)src0->data;
     const float * src1_d = (const float *)src1->data;
 
diff --git a/ggml/src/ggml-sycl/conv.hpp b/ggml/src/ggml-sycl/conv.hpp
index eb20730f904..f9e60dc7580 100644
--- a/ggml/src/ggml-sycl/conv.hpp
+++ b/ggml/src/ggml-sycl/conv.hpp
@@ -15,7 +15,6 @@
 
 #include "common.hpp"
 
-void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-  const ggml_tensor *src1, ggml_tensor *dst);
+void ggml_sycl_op_conv_transpose_1d(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
 
 #endif // GGML_SYCL_CONV_HPP
diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp
index d05a51f807c..4bcd74376ea 100644
--- a/ggml/src/ggml-sycl/element_wise.cpp
+++ b/ggml/src/ggml-sycl/element_wise.cpp
@@ -882,149 +882,149 @@ inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, const ggml_tensor
 }
 
 
-void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqrt);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sqrt);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_sin(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sin);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sin);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_cos(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_cos);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_cos);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_acc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_acc);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_acc);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_gelu);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_silu);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_silu);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu_quick);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_gelu_quick);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_tanh);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_tanh);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_relu);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_relu);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sigmoid);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sigmoid);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardsigmoid);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_hardsigmoid);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardswish);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_hardswish);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
 
-void ggml_sycl_exp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_exp);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_exp);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_log(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_log);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_log);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_neg(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_neg);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_neg);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_step(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_step);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_step);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_leaky_relu);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_leaky_relu);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqr);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sqr);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_upscale);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_upscale);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pad);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_pad);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
 
 
-void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_add);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_add);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_sub(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sub);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sub);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_mul);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_mul);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_div);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_div);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp
index 8152edf5838..46443264505 100644
--- a/ggml/src/ggml-sycl/element_wise.hpp
+++ b/ggml/src/ggml-sycl/element_wise.hpp
@@ -25,52 +25,52 @@ static __dpct_inline__ float op_div(const float a, const float b) {
 }
 
 
-void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_sin(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_cos(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_acc(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_exp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_log(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_neg(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_step(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_sub(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 #endif // GGML_SYCL_ELEMENTWISE_HPP
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 312ccfeb853..037c8093eef 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -54,18 +54,12 @@ static ggml_sycl_device_info ggml_sycl_init() {
     GGML_ASSERT(info.device_count <= GGML_SYCL_MAX_DEVICES);
 
     int64_t total_vram = 0;
-#if defined(GGML_SYCL_FORCE_MMQ)
-    GGML_LOG_INFO("%s: GGML_SYCL_FORCE_MMQ:   yes\n", __func__);
-#else
-    GGML_LOG_INFO("%s: GGML_SYCL_FORCE_MMQ:   no\n", __func__);
-#endif
-#if defined(SYCL_USE_XMX)
-    GGML_LOG_INFO("%s: SYCL_USE_XMX: yes\n", __func__);
-#else
-    GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__);
-#endif
-    GGML_LOG_INFO("%s: found %d %s devices:\n", __func__, info.device_count, GGML_SYCL_NAME);
-
+/* This is a bit misleading;  reserved for later */
+// #if defined(SYCL_USE_XMX)
+//     GGML_LOG_INFO("%s: SYCL_USE_XMX: yes\n", __func__);
+// #else
+//     GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__);
+// #endif
     for (int i = 0; i < info.device_count; ++i) {
         info.devices[i].vmm = 0;
         dpct::device_info prop;
@@ -109,11 +103,11 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
     name = std::regex_replace(name, std::regex("\\(TM\\)"), "");
 
     auto global_mem_size = prop.get_global_mem_size()/1000000;
-
-    GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
+    std::string xmx = gpu_has_xmx(device) ? "yes" : "no";
+    GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|%14s|\n", id, device_type.c_str(),
             name.c_str(), version.c_str(), prop.get_max_compute_units(),
             prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
-            global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
+            global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str(), xmx.c_str());
 }
 
 void ggml_backend_sycl_print_sycl_devices() {
@@ -124,16 +118,16 @@ void ggml_backend_sycl_print_sycl_devices() {
 
     GGML_LOG_INFO(
         "|  |                   |                                       |      "
-        " |Max    |        |Max  |Global |                     |\n");
+        " |Max    |        |Max  |Global |                     |         XMX  |\n");
     GGML_LOG_INFO(
         "|  |                   |                                       |      "
-        " |compute|Max work|sub  |mem    |                     |\n");
+        " |compute|Max work|sub  |mem    |                     |          or  |\n");
     GGML_LOG_INFO(
         "|ID|        Device Type|                                   "
-        "Name|Version|units  |group   |group|size   |       Driver version|\n");
+        "Name|Version|units  |group   |group|size   |       Driver version| Tensor Cores |\n");
     GGML_LOG_INFO(
         "|--|-------------------|---------------------------------------|------"
-        "-|-------|--------|-----|-------|---------------------|\n");
+        "-|-------|--------|-----|-------|---------------------|--------------|\n");
 
     for (int id = 0; id < device_count; ++id) {
       sycl::device device = dpct::dev_mgr::instance().get_device(id);
@@ -164,14 +158,18 @@ static void ggml_check_sycl() try {
     static bool initialized = false;
 
     if (!initialized) {
-        GGML_LOG_INFO("[SYCL] call ggml_check_sycl\n");
+        GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
         g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
-        GGML_LOG_INFO("%s: GGML_SYCL_DEBUG: %d\n", __func__, g_ggml_sycl_debug);
-
+        GGML_LOG_INFO("GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
+#if defined(GGML_SYCL_FORCE_MMQ)
+        GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ:   yes\n");
+#else
+        GGML_LOG_INFO("GGML_SYCL_FORCE_MMQ:   no\n");
+#endif
 #if defined(GGML_SYCL_F16)
-        GGML_LOG_INFO("%s: GGML_SYCL_F16: yes\n", __func__);
+        GGML_LOG_INFO("GGML_SYCL_F16: yes\n");
 #else
-        GGML_LOG_INFO("%s: GGML_SYCL_F16: no\n", __func__);
+        GGML_LOG_INFO("GGML_SYCL_F16: no\n");
 #endif
 
 /* NOT REMOVE, keep it for next optimize for XMX.
@@ -1189,7 +1187,6 @@ std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(q
 /// kernels
 
 typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
-typedef void (*ggml_sycl_func_t)(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
 typedef void (*ggml_sycl_op_mul_mat_t)(
     ggml_backend_sycl_context & ctx,
     const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
@@ -3171,33 +3168,33 @@ catch (sycl::exception const &exc) {
 }
 
 
-static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_repeat);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_repeat);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_get_rows);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_get_rows);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_norm);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_norm);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rms_norm);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rms_norm);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_group_norm);
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_group_norm);
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
@@ -3572,9 +3569,10 @@ __dpct_inline__ static void k_copy_dst_from_contiguous(
     }
 }
 
-static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                 const ggml_tensor *src1,
+static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
                                  ggml_tensor *dst) try {
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];
     GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
 
     const ggml_tensor *ids = dst->src[2];
@@ -3740,12 +3738,12 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
-static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_scale);
+static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_scale);
 }
 
-static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_clamp);
+static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_clamp);
 }
 
 static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
@@ -3787,7 +3785,6 @@ static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *sr
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
         GGML_ABORT("fatal error");
     }
-
     GGML_UNUSED(dst);
 }
 catch (sycl::exception const &exc) {
@@ -3796,59 +3793,52 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
-static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     // TODO: why do we pass dst as src1 here?
-    ggml_sycl_cpy(ctx, src0, dst, nullptr);
-    GGML_UNUSED(src1);
+    ggml_sycl_cpy(ctx, dst->src[0], dst, nullptr);
 }
 
-static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_diag_mask_inf);
+static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_diag_mask_inf);
 }
 
-static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_soft_max);
+static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_soft_max);
 }
 
-static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rope);
+static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_rope);
 }
 
-static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pool2d);
+static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_pool2d);
 }
 
-static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_im2col);
+static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_im2col);
 }
 
-static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum);
+static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sum);
 }
 
-static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum_rows);
+static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_sum_rows);
 }
 
-static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argsort);
+static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_argsort);
 }
 
-static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argmax);
+static void ggml_sycl_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
+    ggml_sycl_op_flatten(ctx, dst->src[0], dst->src[1], dst, ggml_sycl_op_argmax);
 }
 
-static void ggml_sycl_nop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_UNUSED(src0);
-    GGML_UNUSED(src1);
-    GGML_UNUSED(dst);
-    GGML_UNUSED(ctx);
-}
 
 void ggml_sycl_set_main_device(const int main_device) try {
     if (dpct::get_current_device_id() == static_cast<unsigned int> (main_device)) {
@@ -3871,191 +3861,189 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
-bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * tensor) {
+bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * dst) {
     if (!g_sycl_loaded) return false;
 
-    ggml_sycl_func_t func;
+    if (dst->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(dst->src[0]->buffer)) {
+        ggml_sycl_set_peer_access(dst->src[1]->ne[1], ctx.device);
+    }
 
-    switch (tensor->op) {
+    switch (dst->op) {
         case GGML_OP_ARGMAX:
-            func = ggml_sycl_argmax;
+            ggml_sycl_argmax(ctx, dst);
             break;
         case GGML_OP_CONV_TRANSPOSE_1D:
-            func = ggml_sycl_op_conv_transpose_1d;
+            ggml_sycl_op_conv_transpose_1d(ctx, dst);
             break;
         case GGML_OP_REPEAT:
-            func = ggml_sycl_repeat;
+            ggml_sycl_repeat(ctx, dst);
             break;
         case GGML_OP_GET_ROWS:
-            func = ggml_sycl_get_rows;
+            ggml_sycl_get_rows(ctx, dst);
             break;
         case GGML_OP_DUP:
-            func = ggml_sycl_dup;
+            ggml_sycl_dup(ctx, dst);
             break;
         case GGML_OP_ADD:
         case GGML_OP_ADD1: // TODO: more efficient implementation
-            func = ggml_sycl_add;
+            ggml_sycl_add(ctx, dst);
             break;
         case GGML_OP_SUB:
-            func = ggml_sycl_sub;
+            ggml_sycl_sub(ctx, dst);
             break;
         case GGML_OP_ACC:
-            func = ggml_sycl_acc;
+            ggml_sycl_acc(ctx, dst);
             break;
         case GGML_OP_MUL:
-            func = ggml_sycl_mul;
+            ggml_sycl_mul(ctx, dst);
             break;
         case GGML_OP_LOG:
-            func = ggml_sycl_log;
+            ggml_sycl_log(ctx, dst);
             break;
         case GGML_OP_DIV:
-            func = ggml_sycl_div;
+            ggml_sycl_div(ctx, dst);
             break;
         case GGML_OP_UNARY:
-            switch (ggml_get_unary_op(tensor)) {
+            switch (ggml_get_unary_op(dst)) {
                 case GGML_UNARY_OP_NEG:
-                    func = ggml_sycl_neg;
+                    ggml_sycl_neg(ctx, dst);
                     break;
                 case GGML_UNARY_OP_STEP:
-                    func = ggml_sycl_step;
+                    ggml_sycl_step(ctx, dst);
                     break;
                 case GGML_UNARY_OP_GELU:
-                    func = ggml_sycl_gelu;
+                    ggml_sycl_gelu(ctx, dst);
                     break;
                 case GGML_UNARY_OP_SILU:
-                    func = ggml_sycl_silu;
+                    ggml_sycl_silu(ctx, dst);
                     break;
                 case GGML_UNARY_OP_GELU_QUICK:
-                    func = ggml_sycl_gelu_quick;
+                    ggml_sycl_gelu_quick(ctx, dst);
                     break;
                 case GGML_UNARY_OP_TANH:
-                    func = ggml_sycl_tanh;
+                    ggml_sycl_tanh(ctx, dst);
                     break;
                 case GGML_UNARY_OP_RELU:
-                    func = ggml_sycl_relu;
+                    ggml_sycl_relu(ctx, dst);
                     break;
                 case GGML_UNARY_OP_SIGMOID:
-                    func = ggml_sycl_sigmoid;
+                    ggml_sycl_sigmoid(ctx, dst);
                     break;
                 case GGML_UNARY_OP_HARDSIGMOID:
-                    func = ggml_sycl_hardsigmoid;
+                    ggml_sycl_hardsigmoid(ctx, dst);
                     break;
                 case GGML_UNARY_OP_HARDSWISH:
-                    func = ggml_sycl_hardswish;
+                    ggml_sycl_hardswish(ctx, dst);
                     break;
                 case GGML_UNARY_OP_EXP:
-                    func = ggml_sycl_exp;
+                    ggml_sycl_exp(ctx, dst);
                     break;
                 default:
                     return false;
             }
             break;
         case GGML_OP_NORM:
-            func = ggml_sycl_norm;
+            ggml_sycl_norm(ctx, dst);
             break;
         case GGML_OP_GROUP_NORM:
-            func = ggml_sycl_group_norm;
+            ggml_sycl_group_norm(ctx, dst);
             break;
         case GGML_OP_CONCAT:
-            func = ggml_sycl_op_concat;
+            ggml_sycl_op_concat(ctx, dst);
             break;
         case GGML_OP_UPSCALE:
-            func = ggml_sycl_upscale;
+            ggml_sycl_upscale(ctx, dst);
             break;
         case GGML_OP_PAD:
-            func = ggml_sycl_pad;
+            ggml_sycl_pad(ctx, dst);
             break;
         case GGML_OP_LEAKY_RELU:
-            func = ggml_sycl_leaky_relu;
+            ggml_sycl_leaky_relu(ctx, dst);
             break;
         case GGML_OP_RMS_NORM:
-            func = ggml_sycl_rms_norm;
+            ggml_sycl_rms_norm(ctx, dst);
             break;
         case GGML_OP_MUL_MAT:
-            if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
+            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
                 return false;
             }
-            func = ggml_sycl_mul_mat;
+            /* ggml_sycl_mul_mat_id is dependent on ggml_sycl_mul_mat */
+            ggml_sycl_mul_mat(ctx, dst->src[0], dst->src[1], dst);
             break;
         case GGML_OP_MUL_MAT_ID:
-            if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
+            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
                 return false;
             }
-            func = ggml_sycl_mul_mat_id;
+            ggml_sycl_mul_mat_id(ctx, dst);
             break;
         case GGML_OP_OUT_PROD:
-            func = ggml_sycl_op_out_prod;
+            ggml_sycl_op_out_prod(ctx, dst);
             break;
         case GGML_OP_SCALE:
-            func = ggml_sycl_scale;
+            ggml_sycl_scale(ctx, dst);
             break;
         case GGML_OP_SQR:
-            func = ggml_sycl_sqr;
+            ggml_sycl_sqr(ctx, dst);
             break;
         case GGML_OP_SQRT:
-            func = ggml_sycl_sqrt;
+            ggml_sycl_sqrt(ctx, dst);
             break;
         case GGML_OP_SIN:
-            func = ggml_sycl_sin;
+            ggml_sycl_sin(ctx, dst);
             break;
         case GGML_OP_COS:
-            func = ggml_sycl_cos;
+            ggml_sycl_cos(ctx, dst);
             break;
         case GGML_OP_CLAMP:
-            func = ggml_sycl_clamp;
+            ggml_sycl_clamp(ctx, dst);
             break;
         case GGML_OP_CPY:
-            func = ggml_sycl_cpy;
+            ggml_sycl_cpy(ctx, dst->src[0], dst->src[1], dst);
             break;
         case GGML_OP_CONT:
-            func = ggml_sycl_dup;
+            ggml_sycl_dup(ctx, dst);
             break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
-            func = ggml_sycl_nop;
+            GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
             break;
         case GGML_OP_DIAG_MASK_INF:
-            func = ggml_sycl_diag_mask_inf;
+            ggml_sycl_diag_mask_inf(ctx, dst);
             break;
         case GGML_OP_SOFT_MAX:
-            func = ggml_sycl_soft_max;
+            ggml_sycl_soft_max(ctx, dst);
             break;
         case GGML_OP_ROPE:
-            func = ggml_sycl_rope;
+            ggml_sycl_rope(ctx, dst);
             break;
         case GGML_OP_IM2COL:
-            func = ggml_sycl_im2col;
+            ggml_sycl_im2col(ctx, dst);
             break;
         case GGML_OP_POOL_2D:
-            func = ggml_sycl_pool2d;
+            ggml_sycl_pool2d(ctx, dst);
             break;
         case GGML_OP_SUM:
-            func = ggml_sycl_sum;
+            ggml_sycl_sum(ctx, dst);
             break;
         case GGML_OP_SUM_ROWS:
-            func = ggml_sycl_sum_rows;
+            ggml_sycl_sum_rows(ctx, dst);
             break;
         case GGML_OP_ARGSORT:
-            func = ggml_sycl_argsort;
+            ggml_sycl_argsort(ctx, dst);
             break;
         case GGML_OP_TIMESTEP_EMBEDDING:
-            func = ggml_sycl_op_timestep_embedding;
+            ggml_sycl_op_timestep_embedding(ctx, dst);
             break;
         case GGML_OP_RWKV_WKV6:
-            func = ggml_sycl_op_rwkv_wkv6;
+            ggml_sycl_op_rwkv_wkv6(ctx, dst);
             break;
         default:
             return false;
     }
 
-    if (tensor->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(tensor->src[0]->buffer)) {
-        ggml_sycl_set_peer_access(tensor->src[1]->ne[1], ctx.device);
-    }
-
-    func(ctx, tensor->src[0], tensor->src[1], tensor);
     return true;
 }
 
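
With the ggml_sycl_func_t function-pointer table removed, ggml_sycl_compute_forward now calls each op directly from the switch, and the split-buffer peer-access setup runs once before dispatch instead of after it. A condensed view of the new dispatcher shape, using only cases that appear in the hunk above:

    switch (dst->op) {
        case GGML_OP_ADD:
            ggml_sycl_add(ctx, dst);                      // direct call, no function-pointer indirection
            break;
        case GGML_OP_RESHAPE:                             // view-like ops: nothing to compute
            GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
            break;
        default:
            return false;                                 // unsupported op: let another backend handle it
    }
    return true;
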
diff --git a/ggml/src/ggml-sycl/outprod.cpp b/ggml/src/ggml-sycl/outprod.cpp
index ef9af0b7633..8e8347ff4f9 100644
--- a/ggml/src/ggml-sycl/outprod.cpp
+++ b/ggml/src/ggml-sycl/outprod.cpp
@@ -3,9 +3,9 @@
 #include "outprod.hpp"
 
 
-void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
-    const ggml_tensor* src1, ggml_tensor* dst) {
-
+void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
diff --git a/ggml/src/ggml-sycl/outprod.hpp b/ggml/src/ggml-sycl/outprod.hpp
index 9c042738a48..f50413d3f7a 100644
--- a/ggml/src/ggml-sycl/outprod.hpp
+++ b/ggml/src/ggml-sycl/outprod.hpp
@@ -3,8 +3,7 @@
 
 #include "common.hpp"
 
-void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
-    const ggml_tensor* src1, ggml_tensor* dst);
+void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
 
 
 #endif // GGML_SYCL_OUTPROD_HPP
diff --git a/ggml/src/ggml-sycl/tsembd.cpp b/ggml/src/ggml-sycl/tsembd.cpp
index 2ffe3cca917..b877d18c173 100644
--- a/ggml/src/ggml-sycl/tsembd.cpp
+++ b/ggml/src/ggml-sycl/tsembd.cpp
@@ -55,8 +55,9 @@ static void timestep_embedding_f32_sycl(
         });
 }
 
-void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-    const ggml_tensor *src1, ggml_tensor * dst) {
+void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];
     const float * src0_d = (const float *)src0->data;
     float * dst_d = (float *)dst->data;
     dpct::queue_ptr stream = ctx.stream();
diff --git a/ggml/src/ggml-sycl/tsembd.hpp b/ggml/src/ggml-sycl/tsembd.hpp
index ff854c337c3..4c18748bbff 100644
--- a/ggml/src/ggml-sycl/tsembd.hpp
+++ b/ggml/src/ggml-sycl/tsembd.hpp
@@ -15,7 +15,6 @@
 
 #include "common.hpp"
 
-void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-    const ggml_tensor *src1, ggml_tensor * dst);
+void ggml_sycl_op_timestep_embedding(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 #endif // GGML_SYCL_TSEMBD_HPP
diff --git a/ggml/src/ggml-sycl/wkv6.cpp b/ggml/src/ggml-sycl/wkv6.cpp
index 75ddfb86ac0..b54c20964ed 100644
--- a/ggml/src/ggml-sycl/wkv6.cpp
+++ b/ggml/src/ggml-sycl/wkv6.cpp
@@ -95,8 +95,10 @@ static void rwkv_wkv_f32_kernel(
     }
 }
 
-void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
-    const ggml_tensor* src1, ggml_tensor* dst) {
+void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
+
+    const ggml_tensor *src0 = dst->src[0];
+    const ggml_tensor *src1 = dst->src[1];
 
     const float* k_d = (const float*)dst->src[0]->data;
     const float* v_d = (const float*)dst->src[1]->data;
@@ -107,9 +109,9 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* s
     float* dst_d = (float*)dst->data;
 
     const int64_t B = dst->src[5]->ne[1];
-    const int64_t T = dst->src[0]->ne[3];
+    const int64_t T = dst->src[0]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t H = dst->src[0]->ne[2];
+    const int64_t H = dst->src[0]->ne[1];
 
     GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
     GGML_ASSERT(C % H == 0);
@@ -131,7 +133,7 @@ void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* s
             [=](sycl::nd_item<3> item_ct1) {
                 rwkv_wkv_f32_kernel(
                     B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
-                    item_ct1, shared_mem_acc.get_pointer()
+                    item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
                 );
             });
     });
diff --git a/ggml/src/ggml-sycl/wkv6.hpp b/ggml/src/ggml-sycl/wkv6.hpp
index ddfa3377b48..8c596a99722 100644
--- a/ggml/src/ggml-sycl/wkv6.hpp
+++ b/ggml/src/ggml-sycl/wkv6.hpp
@@ -3,8 +3,7 @@
 
 #include "common.hpp"
 
-void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-    const ggml_tensor *src1, ggml_tensor * dst);
+void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 
 #endif // GGML_SYCL_WKV6_HPP
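
The wkv6.cpp hunk above also adapts to the new 3-D operand layout introduced in ggml.c later in this patch (the Vulkan backend below makes the matching change): k/v/r/td no longer carry a dummy unit dimension, so the kernel's B/T/C/H parameters are read from shifted indices. A quick reference, taken directly from the hunks:

    // new layout of dst->src[0] (k): ne[0] = S (head size), ne[1] = H (heads), ne[2] = T (tokens)
    const int64_t B = dst->src[5]->ne[1]; // unchanged: number of sequences
    const int64_t T = dst->src[0]->ne[2]; // was ne[3]
    const int64_t C = dst->ne[0];         // unchanged: channels
    const int64_t H = dst->src[0]->ne[1]; // was ne[2]
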
diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt
index 6d46e5f24c1..c0ddaac827f 100644
--- a/ggml/src/ggml-vulkan/CMakeLists.txt
+++ b/ggml/src/ggml-vulkan/CMakeLists.txt
@@ -8,6 +8,20 @@ if (Vulkan_FOUND)
                              ../../include/ggml-vulkan.h
                             )
 
+    # Compile a test shader to determine whether GL_KHR_cooperative_matrix is supported.
+    # If it's not, there will be an error to stderr.
+    # If it's supported, set a define to indicate that we should compile those shaders
+    execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat_support.comp"
+                    OUTPUT_VARIABLE glslc_output
+                    ERROR_VARIABLE glslc_error)
+
+    if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*")
+        message(STATUS "GL_KHR_cooperative_matrix not supported by glslc")
+    else()
+        message(STATUS "GL_KHR_cooperative_matrix supported by glslc")
+        add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+    endif()
+
     # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported.
     # If it's not, there will be an error to stderr.
     # If it's supported, set a define to indicate that we should compile those shaders
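
(Either probe can be reproduced by hand with `glslc -o - -fshader-stage=compute --target-env=vulkan1.3 test_coopmat_support.comp`: a glslc without the extension prints an "extension not supported" error on stderr, which is exactly what the MATCHES check above keys on.)
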
@@ -69,6 +83,10 @@ if (Vulkan_FOUND)
 
     file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp")
 
+    if (NOT CMAKE_CROSSCOMPILING)
+        set(_ggml_vk_genshaders_cmd "$<TARGET_FILE_DIR:vulkan-shaders-gen>/${_ggml_vk_genshaders_cmd}")
+    endif ()
+
     add_custom_command(
         OUTPUT ${_ggml_vk_header}
                 ${_ggml_vk_source}
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 020e612801f..649146d7b45 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -1645,6 +1645,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 #undef CREATE_MM2
     } else
 #endif  // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
     if (device->coopmat_support) {
         // Create 6 variants, {s,m,l}x{unaligned,aligned}
 #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@@ -1739,7 +1740,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
         }
 #undef CREATE_MM2
 #undef CREATE_MM
-    } else if (device->fp16) {
+    } else
+#endif  // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+    if (device->fp16) {
         // Create 6 variants, {s,m,l}x{unaligned,aligned}
 #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
         if (device->mul_mat ## ID ## _l) \
@@ -2040,6 +2043,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     std::cerr << "Done!" << std::endl;
 }
 
+static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
+
 static vk_device ggml_vk_get_device(size_t idx) {
     VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
 
@@ -2175,9 +2180,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
 
-        if (device->vendor_id == VK_VENDOR_ID_INTEL || (device->vendor_id == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) {
-            // Intel drivers don't support coopmat properly yet
-            // Only RADV supports coopmat properly on AMD
+        if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) {
             device->coopmat_support = false;
         }
 
@@ -2242,6 +2245,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
             last_struct = (VkBaseOutStructure *)&subgroup_size_control_features;
         }
 
+#if defined(VK_KHR_cooperative_matrix)
         VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
         coopmat_features.pNext = nullptr;
         coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
@@ -2251,6 +2255,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
             last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
             last_struct = (VkBaseOutStructure *)&coopmat_features;
         }
+#endif
 
 #if defined(VK_NV_cooperative_matrix2)
         VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
@@ -2272,6 +2277,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         if (device->subgroup_size_control) {
             device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
             device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
+            device_extensions.push_back("VK_EXT_subgroup_size_control");
         }
 
         device->subgroup_size_control = device->subgroup_size_control &&
@@ -2280,10 +2286,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         if (device->subgroup_size_control) {
             device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
-            device_extensions.push_back("VK_EXT_subgroup_size_control");
         }
 
+#if defined(VK_KHR_cooperative_matrix)
         device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
+#endif
 
         if (coopmat2_support) {
 #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
@@ -2376,6 +2383,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device_extensions.push_back("VK_KHR_shader_float16_int8");
         }
 
+#if defined(VK_KHR_cooperative_matrix)
         if (device->coopmat_support) {
             // Query supported shapes
             std::vector<VkCooperativeMatrixPropertiesKHR> cm_props;
@@ -2442,7 +2450,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         if (device->coopmat_support) {
             device_extensions.push_back("VK_KHR_cooperative_matrix");
         }
-
+#endif
         device->name = GGML_VK_NAME + std::to_string(idx);
 
         device_create_info = {
@@ -2515,7 +2523,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
     return vk_instance.devices[idx];
 }
 
-
 static void ggml_vk_print_gpu_info(size_t idx) {
     GGML_ASSERT(idx < vk_instance.device_indices.size());
     size_t dev_num = vk_instance.device_indices[idx];
@@ -2554,9 +2561,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
             fp16_storage = true;
         } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
             fp16_compute = true;
-        } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
+#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+        } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
                    !getenv("GGML_VK_DISABLE_COOPMAT")) {
             coopmat_support = true;
+#endif
 #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
         } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
                    !getenv("GGML_VK_DISABLE_COOPMAT2")) {
@@ -2565,9 +2574,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
         }
     }
 
-    if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) {
-        // Intel drivers don't support coopmat properly yet
-        // Only RADV supports coopmat properly on AMD
+    if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) {
         coopmat_support = false;
     }
 
@@ -2596,6 +2603,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
     // Pointer to the last chain element
     VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features;
 
+#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
     VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
     coopmat_features.pNext = nullptr;
     coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
@@ -2611,6 +2619,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
     fp16 = fp16 && vk12_features.shaderFloat16;
 
     coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix;
+#endif
 
     std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
 
@@ -5624,9 +5633,9 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
 }
 
 static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
-    const size_t seq_length = dst->src[0]->ne[3];
+    const size_t seq_length = dst->src[0]->ne[2];
     const size_t n_embed = dst->ne[0];
-    const size_t n_heads = dst->src[0]->ne[2];
+    const size_t n_heads = dst->src[0]->ne[1];
     const size_t n_seqs = dst->src[5]->ne[1];
 
     ggml_vk_op_f32_rwkv6(
@@ -8088,6 +8097,25 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
     UNUSED(instance_extensions);
 }
 
+static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) {
+    switch (props.vendorID) {
+    case VK_VENDOR_ID_INTEL:
+        // Intel drivers don't support coopmat properly yet
+        return false;
+    case VK_VENDOR_ID_AMD:
+        if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
+            // Workaround for AMD proprietary driver reporting support on all GPUs
+            const std::string name = props.deviceName;
+            return name.rfind("AMD Radeon RX 7", 0) == 0   || name.rfind("AMD Radeon(TM) RX 7", 0) == 0   || // RDNA 3 consumer GPUs
+                   name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs
+                   name.rfind("AMD Radeon 7", 0) == 0      || name.rfind("AMD Radeon(TM) 7", 0) == 0;        // RDNA 3 APUs
+        }
+        return true;
+    default:
+        return true;
+    }
+}
+
 // checks
 
 #ifdef GGML_VULKAN_CHECK_RESULTS
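
The new helper centralizes the driver matrix that was previously duplicated inline: Intel is always rejected, the AMD proprietary and AMDVLK drivers are accepted only for RDNA 3 device-name prefixes, and everything else (including AMD on RADV) passes through. A few illustrative evaluations against the prefixes listed above (device names are examples, not from the patch):

    // with driverID == vk::DriverId::eAmdProprietary:
    //   "AMD Radeon RX 7900 XTX"  -> true   (matches "AMD Radeon RX 7")
    //   "AMD Radeon RX 6800 XT"   -> false  (no RDNA 3 prefix)
    // with driverID == vk::DriverId::eMesaRadv:
    //   any AMD device            -> true   (falls through to the final `return true`)
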
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
index 24875cdcf4c..53902858de7 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
@@ -1,9 +1,6 @@
 #version 450
 
-#ifdef FLOAT16
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-#endif
-#extension GL_EXT_shader_explicit_arithmetic_types : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 
 #include "mul_mat_vec_base.comp"
 
@@ -27,8 +24,8 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
 
 #if K_PER_ITER == 8
 #if QUANT_R == 2
-        const B_TYPE_VEC4 bv02 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4];
-        const B_TYPE_VEC4 bv13 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4];
+        const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
+        const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]);
         const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
         const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
 #else
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
index 93421344624..6a9b9b2d132 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
@@ -1,5 +1,5 @@
 #version 450
-#extension GL_EXT_shader_explicit_arithmetic_types : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 
 #include "mul_mat_vec_base.comp"
 
@@ -40,9 +40,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
 
         [[unroll]] for (uint n = 0; n < num_rows; ++n) {
             const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
-            f16vec2 d = data_a[ib0 + i].d;
-            const FLOAT_TYPE dall = d.x;
-            const FLOAT_TYPE dmin = d.y;
+            vec2 d = vec2(data_a[ib0 + i].d);
+            const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
+            const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
 
             uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
             uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
@@ -63,14 +63,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
             uvec2 qs16 = uvec2(unpack8(qs16_u16));
 
             [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-                B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0];
-                B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
-                B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
-                B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
-                B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
-                B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
-                B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
-                B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
+                vec2 b0 =   vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  0]);
+                vec2 b16 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  8]);
+                vec2 b32 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
+                vec2 b48 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
+                vec2 b64 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
+                vec2 b80 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
+                vec2 b96 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
+                vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
 
                 FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
                 FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
index 86b0159d97a..96ef50fdda2 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
@@ -1,5 +1,5 @@
 #version 450
-#extension GL_EXT_shader_explicit_arithmetic_types : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 
 #include "mul_mat_vec_base.comp"
 
@@ -60,14 +60,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
 
             [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
 
-                B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0];
-                B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
-                B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
-                B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
-                B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
-                B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
-                B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
-                B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
+                vec2 b0 =   vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  0]);
+                vec2 b16 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  8]);
+                vec2 b32 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
+                vec2 b48 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
+                vec2 b64 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
+                vec2 b80 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
+                vec2 b96 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
+                vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
 
                 FLOAT_TYPE sum = FLOAT_TYPE(0.0);
                 [[unroll]] for (int l = 0; l < 2; ++l) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
index cd1dd8e89c2..f97eb8744fb 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp
@@ -1,6 +1,6 @@
 #version 450
 
-#extension GL_EXT_shader_explicit_arithmetic_types : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 
 #include "mul_mat_vec_base.comp"
 
@@ -45,7 +45,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
 
         [[unroll]] for (uint n = 0; n < num_rows; ++n) {
             const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
-            f16vec2 d = data_a[ib0 + i].d;
+            vec2 d = vec2(data_a[ib0 + i].d);
             const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
             const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
 
@@ -96,10 +96,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
             const uint32_t q4_15 = qs64_hi4.w;
 
             [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-                B_TYPE_VEC4 by10 =  data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4];
-                B_TYPE_VEC4 by132 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8];
-                B_TYPE_VEC4 by20 =  data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4];
-                B_TYPE_VEC4 by232 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8];
+                vec4 by10 =  vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4    ]);
+                vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]);
+                vec4 by20 =  vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4    ]);
+                vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]);
 
                 const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x),      q4_0,  fma(FLOAT_TYPE(by10.y),  q4_1,  fma(FLOAT_TYPE(by10.z),  q4_2,  FLOAT_TYPE(by10.w) *  q4_3)));
                 const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x),     q4_4,  fma(FLOAT_TYPE(by132.y), q4_5,  fma(FLOAT_TYPE(by132.z), q4_6,  FLOAT_TYPE(by132.w) * q4_7)));
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
index 0a68891c35a..79d7db0e3e6 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp
@@ -1,6 +1,6 @@
 #version 450
 
-#extension GL_EXT_shader_explicit_arithmetic_types : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 
 #include "mul_mat_vec_base.comp"
 
@@ -42,7 +42,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
 
         [[unroll]] for (uint n = 0; n < num_rows; ++n) {
             const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
-            f16vec2 d = data_a[ib0 + i].d;
+            vec2 d = vec2(data_a[ib0 + i].d);
             const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
             const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
 
@@ -105,14 +105,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
             const uint32_t q4_15 = qs64_80_hi4.w;
 
             [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-                B_TYPE_VEC2 by10 =  data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2];
-                B_TYPE_VEC2 by116 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8];
-                B_TYPE_VEC2 by132 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16];
-                B_TYPE_VEC2 by148 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24];
-                B_TYPE_VEC2 by20 =  data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2];
-                B_TYPE_VEC2 by216 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8];
-                B_TYPE_VEC2 by232 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16];
-                B_TYPE_VEC2 by248 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24];
+                vec2 by10 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2     ]);
+                vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 +  8]);
+                vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]);
+                vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]);
+                vec2 by20 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2     ]);
+                vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 +  8]);
+                vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]);
+                vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]);
 
                 const FLOAT_TYPE sx =
                   fma(FLOAT_TYPE(by10.x), q4_0,
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp
index 70e13a56bd7..041fd27c12b 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp
@@ -1,6 +1,6 @@
 #version 450
 
-#extension GL_EXT_shader_explicit_arithmetic_types : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 
 #include "mul_mat_vec_base.comp"
 
@@ -77,10 +77,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
             uvec4 q3 = uvec4(unpack8(q3_u32));
 
             [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
-                B_TYPE_VEC4 by0  = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4];
-                B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8];
-                B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16];
-                B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24];
+                vec4 by0  = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4     ]);
+                vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 +  8]);
+                vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]);
+                vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]);
 
                 FLOAT_TYPE sum = FLOAT_TYPE(0.0);
                 [[unroll]] for (int l = 0; l < 4; ++l) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp
index a25808e1656..51fc2dc7ed4 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp
@@ -1,6 +1,5 @@
 #version 450
 
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
 #extension GL_EXT_control_flow_attributes : enable
 
 layout (push_constant) uniform parameter
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp b/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp
new file mode 100644
index 00000000000..8c5dd1bd167
--- /dev/null
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/test_coopmat_support.comp
@@ -0,0 +1,7 @@
+#version 460
+
+#extension GL_KHR_cooperative_matrix : require
+
+void main()
+{
+}
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp
index eecc47f3a97..f12e61bbe10 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp
@@ -2,7 +2,10 @@
 #if !defined(GGML_TYPES_COMP)
 #define GGML_TYPES_COMP
 
-#extension GL_EXT_shader_explicit_arithmetic_types : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
+#extension GL_EXT_shader_16bit_storage : require
 
 #if defined(DATA_A_F32)
 #define QUANT_K 1
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index 8111c063884..7b5044798d7 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -342,9 +342,11 @@ void process_shaders() {
         matmul_shaders(true, matmul_id, false, false, false);
         matmul_shaders(true, matmul_id, false, false, true);
 
+#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
         // Coopmat, fp32acc and fp16acc
         matmul_shaders(true, matmul_id, true, false, false);
         matmul_shaders(true, matmul_id, true, false, true);
+#endif
 
 #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
         // Coopmat2, fp32acc and fp16acc
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 2bbe5f48257..da5b817e156 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -968,6 +968,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GET_REL_POS",
     "ADD_REL_POS",
     "RWKV_WKV6",
+    "GATED_LINEAR_ATTN",
 
     "UNARY",
 
@@ -987,7 +988,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1064,6 +1065,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "get_rel_pos(x)",
     "add_rel_pos(x)",
     "rwkv_wkv6(k, v, r, tf, td, s)",
+    "gated_linear_attn(k, v, q, gate, s)",
 
     "unary(x)",
 
@@ -1083,7 +1085,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
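
Adding GATED_LINEAR_ATTN means every op table must stay in lock-step: the name table, the symbol table, and both GGML_OP_COUNT static_asserts move from 82 to 83. Not shown in this excerpt but implied by these hunks, the op is presumably also appended to enum ggml_op in the public header at the matching position; a sketch of that assumed change, with the neighboring entries taken from the tables above:

    enum ggml_op {
        // ...
        GGML_OP_RWKV_WKV6,
        GGML_OP_GATED_LINEAR_ATTN,   // new in this patch

        GGML_OP_UNARY,
        // ...
    };
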
@@ -1588,15 +1590,8 @@ static struct ggml_tensor * ggml_new_tensor_impl(
 
     struct ggml_tensor * const result = (struct ggml_tensor *)((char *)ctx->mem_buffer + obj_new->offs);
 
-#ifdef __clang__
-    // temporary until ggml_tensor::backend is removed
-    #pragma clang diagnostic push
-    #pragma clang diagnostic ignored "-Wdeprecated-declarations"
-#endif
-
     *result = (struct ggml_tensor) {
         /*.type         =*/ type,
-        /*.backend      =*/ GGML_BACKEND_TYPE_CPU,
         /*.buffer       =*/ NULL,
         /*.ne           =*/ { 1, 1, 1, 1 },
         /*.nb           =*/ { 0, 0, 0, 0 },
@@ -1612,10 +1607,6 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         /*.padding      =*/ { 0 },
     };
 
-#ifdef __clang__
-    #pragma clang diagnostic pop
-#endif
-
     // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
     //GGML_ASSERT_ALIGNED(result->data);
 
@@ -4640,15 +4631,13 @@ struct ggml_tensor * ggml_rwkv_wkv6(
     GGML_ASSERT(ggml_is_contiguous(state));
 
     const int64_t S = k->ne[0];
-    const int64_t H = k->ne[2];
-    const int64_t n_tokens = k->ne[3];
+    const int64_t H = k->ne[1];
+    const int64_t n_tokens = k->ne[2];
     const int64_t n_seqs = state->ne[1];
     {
-        GGML_ASSERT(k->ne[1] == 1);
-        GGML_ASSERT(v->ne[0] == 1 && v->ne[1] == S && v->ne[2] == H && v->ne[3] == n_tokens);
-        GGML_ASSERT(r->ne[0] == 1 && r->ne[1] == S && r->ne[2] == H && r->ne[3] == n_tokens);
-        // TODO: RWKV v4 and v5
-        GGML_ASSERT(td->ne[0] == 1 && td->ne[1] == S && td->ne[2] == H && td->ne[3] == n_tokens);
+        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+        GGML_ASSERT(r->ne[0] == S && r->ne[1] == H && r->ne[2] == n_tokens);
+        GGML_ASSERT(td->ne[0] == S && td->ne[1] == H && td->ne[2] == n_tokens);
         GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
     }
 
@@ -4667,6 +4656,49 @@ struct ggml_tensor * ggml_rwkv_wkv6(
     return result;
 }
 
+// ggml_gated_linear_attn
+
+struct ggml_tensor * ggml_gated_linear_attn(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * g,
+        struct ggml_tensor  * state,
+        float scale) {
+    GGML_ASSERT(ggml_is_contiguous(k));
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(q));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S = k->ne[0];
+    const int64_t H = k->ne[1];
+    const int64_t n_tokens = k->ne[2];
+    const int64_t n_seqs = state->ne[1];
+    {
+        GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
+        GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
+        GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
+        GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    }
+
+    // concat output and new_state
+    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    ggml_set_op_params_f32(result, 0, scale);
+
+    result->op     = GGML_OP_GATED_LINEAR_ATTN;
+    result->src[0] = k;
+    result->src[1] = v;
+    result->src[2] = q;
+    result->src[3] = g;
+    result->src[4] = state;
+
+    return result;
+}
+
 // ggml_unary
 
 static struct ggml_tensor * ggml_unary_impl(
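
The new operator mirrors ggml_rwkv_wkv6: it takes per-head keys, values, queries and gates of shape [S, H, n_tokens], a recurrent state with S*S*H*n_seqs elements, and returns the per-token output concatenated with the updated state. A hedged usage sketch built only from the asserts and signature above (ctx, S, H, n_tokens and n_seqs are placeholders the caller must provide; the scale value is illustrative):

    // shapes follow the GGML_ASSERTs in ggml_gated_linear_attn
    struct ggml_tensor * k     = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
    struct ggml_tensor * v     = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
    struct ggml_tensor * q     = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
    struct ggml_tensor * g     = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens);
    struct ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, S * S * H, n_seqs);

    const float scale = 1.0f;  // stored via ggml_set_op_params_f32; pick per model
    // rows 0 .. n_tokens-1 of the result hold the output (S*H values per token),
    // the remaining S*n_seqs rows hold the updated state
    struct ggml_tensor * out = ggml_gated_linear_attn(ctx, k, v, q, g, state, scale);
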
@@ -6417,1271 +6449,6 @@ size_t ggml_quantize_chunk(
 
 ////////////////////////////////////////////////////////////////////////////////
 
-struct gguf_str {
-    uint64_t n;  // GGUFv2
-    char * data;
-};
-
-static const size_t GGUF_TYPE_SIZE[GGUF_TYPE_COUNT] = {
-    [GGUF_TYPE_UINT8]   = sizeof(uint8_t),
-    [GGUF_TYPE_INT8]    = sizeof(int8_t),
-    [GGUF_TYPE_UINT16]  = sizeof(uint16_t),
-    [GGUF_TYPE_INT16]   = sizeof(int16_t),
-    [GGUF_TYPE_UINT32]  = sizeof(uint32_t),
-    [GGUF_TYPE_INT32]   = sizeof(int32_t),
-    [GGUF_TYPE_FLOAT32] = sizeof(float),
-    [GGUF_TYPE_BOOL]    = sizeof(bool),
-    [GGUF_TYPE_STRING]  = sizeof(struct gguf_str),
-    [GGUF_TYPE_UINT64]  = sizeof(uint64_t),
-    [GGUF_TYPE_INT64]   = sizeof(int64_t),
-    [GGUF_TYPE_FLOAT64] = sizeof(double),
-    [GGUF_TYPE_ARRAY]   = 0, // undefined
-};
-static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
-
-static const char * GGUF_TYPE_NAME[GGUF_TYPE_COUNT] = {
-    [GGUF_TYPE_UINT8]   = "u8",
-    [GGUF_TYPE_INT8]    = "i8",
-    [GGUF_TYPE_UINT16]  = "u16",
-    [GGUF_TYPE_INT16]   = "i16",
-    [GGUF_TYPE_UINT32]  = "u32",
-    [GGUF_TYPE_INT32]   = "i32",
-    [GGUF_TYPE_FLOAT32] = "f32",
-    [GGUF_TYPE_BOOL]    = "bool",
-    [GGUF_TYPE_STRING]  = "str",
-    [GGUF_TYPE_ARRAY]   = "arr",
-    [GGUF_TYPE_UINT64]  = "u64",
-    [GGUF_TYPE_INT64]   = "i64",
-    [GGUF_TYPE_FLOAT64] = "f64",
-};
-static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
-
-union gguf_value {
-    uint8_t  uint8;
-    int8_t   int8;
-    uint16_t uint16;
-    int16_t  int16;
-    uint32_t uint32;
-    int32_t  int32;
-    float    float32;
-    uint64_t uint64;
-    int64_t  int64;
-    double   float64;
-    bool     bool_;
-
-    struct gguf_str str;
-
-    struct {
-        enum gguf_type type;
-
-        uint64_t n;  // GGUFv2
-        void * data;
-    } arr;
-};
-
-struct gguf_kv {
-    struct gguf_str key;
-
-    enum  gguf_type  type;
-    union gguf_value value;
-};
-
-struct gguf_header {
-    char magic[4];
-
-    uint32_t version;
-    uint64_t n_tensors; // GGUFv2
-    uint64_t n_kv;      // GGUFv2
-};
-
-struct gguf_tensor_info {
-    struct gguf_str name;
-
-    uint32_t n_dims;
-    uint64_t ne[GGML_MAX_DIMS];
-
-    enum ggml_type type;
-
-    uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT`
-
-    // for writing API
-    const void * data;
-    size_t size;
-};
-
-struct gguf_context {
-    struct gguf_header header;
-
-    struct gguf_kv          * kv;
-    struct gguf_tensor_info * infos;
-
-    size_t alignment;
-    size_t offset;    // offset of `data` from beginning of file
-    size_t size;      // size of `data` in bytes
-
-    //uint8_t * padding;
-    void * data;
-};
-
-size_t gguf_type_size(enum gguf_type type) {
-    GGML_ASSERT(0 <= type && type < GGUF_TYPE_COUNT);
-    return GGUF_TYPE_SIZE[type];
-}
-
-static bool gguf_tensor_info_sanitize(struct gguf_tensor_info * info) {
-    if (info->n_dims > GGML_MAX_DIMS) {
-        fprintf(stderr, "%s: invalid number of dimensions (%" PRIu32 ")\n", __func__, info->n_dims);
-        return false;
-    }
-
-    if (info->type < 0 || info->type >= GGML_TYPE_COUNT) {
-        fprintf(stderr, "%s: invalid type (%d)\n", __func__, info->type);
-        return false;
-    }
-
-    if (strlen(info->name.data) >= GGML_MAX_NAME) {
-        fprintf(stderr, "%s: tensor '%s' name is too long\n", __func__, info->name.data);
-        return false;
-    }
-
-    for (uint32_t i = 0; i < info->n_dims; ++i) {
-        if (info->ne[i] <= 0) {
-            fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[i]);
-            return false;
-        }
-    }
-
-    // prevent overflow for total number of elements
-    if (INT64_MAX/info->ne[1] <= info->ne[0]) {
-        fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[1]);
-        return false;
-    }
-
-    if (INT64_MAX/info->ne[2] <= info->ne[0]*info->ne[1]) {
-        fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[2]);
-        return false;
-    }
-
-    if (INT64_MAX/info->ne[3] <= info->ne[0]*info->ne[1]*info->ne[2]) {
-        fprintf(stderr, "%s: invalid number of elements (%" PRIu64 ")\n", __func__, info->ne[3]);
-        return false;
-    }
-
-    return true;
-}
-
-static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) {
-    const size_t n = fread(dst, 1, size, file);
-    *offset += n;
-    return n == size;
-}
-
-static bool gguf_fread_str(FILE * file, struct gguf_str * p, size_t * offset) {
-    p->n    = 0;
-    p->data = NULL;
-
-    bool ok = true;
-
-    ok = ok && gguf_fread_el(file, &p->n, sizeof(p->n), offset);
-
-    // early exit if string length is invalid, prevents from integer overflow
-    if (p->n == SIZE_MAX) {
-        fprintf(stderr, "%s: invalid string length (%" PRIu64 ")\n", __func__, p->n);
-        return false;
-    }
-
-    p->data = calloc(p->n + 1, 1);
-    if (!p->data) {
-        fprintf(stderr, "%s: failed to allocate memory for string of length %" PRIu64 "\n", __func__, p->n);
-        return false;
-    }
-
-    ok = ok && gguf_fread_el(file,  p->data, p->n, offset);
-
-    return ok;
-}
-
-static void gguf_free_kv(struct gguf_kv * kv) {
-    if (kv->key.data) {
-        GGML_FREE(kv->key.data);
-    }
-
-    if (kv->type == GGUF_TYPE_STRING) {
-        if (kv->value.str.data) {
-            GGML_FREE(kv->value.str.data);
-        }
-    }
-
-    if (kv->type == GGUF_TYPE_ARRAY) {
-        if (kv->value.arr.data) {
-            if (kv->value.arr.type == GGUF_TYPE_STRING) {
-                for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
-                    struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[j];
-                    if (str->data) {
-                        GGML_FREE(str->data);
-                    }
-                }
-            }
-            GGML_FREE(kv->value.arr.data);
-        }
-    }
-}
-
-struct gguf_context * gguf_init_empty(void) {
-    struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context));
-    if (!ctx) {
-        fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
-        return NULL;
-    }
-
-    memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic));
-    ctx->header.version   = GGUF_VERSION;
-    ctx->header.n_tensors = 0;
-    ctx->header.n_kv      = 0;
-
-    ctx->kv    = NULL;
-    ctx->infos = NULL;
-
-    ctx->alignment = GGUF_DEFAULT_ALIGNMENT;
-    ctx->offset    = 0;
-    ctx->size      = 0;
-
-    ctx->data = NULL;
-
-    return ctx;
-}
-
-struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) {
-    // offset from start of file
-    size_t offset = 0;
-
-    char magic[4];
-
-    // check the magic before making allocations
-    {
-        gguf_fread_el(file, &magic, sizeof(magic), &offset);
-
-        for (uint32_t i = 0; i < sizeof(magic); i++) {
-            if (magic[i] != GGUF_MAGIC[i]) {
-                fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
-                return NULL;
-            }
-        }
-    }
-
-    bool ok = true;
-
-    struct gguf_context * ctx = calloc(1, sizeof(struct gguf_context));
-    if (!ctx) {
-        fprintf(stderr, "%s: failed to allocate memory for context\n", __func__);
-        return NULL;
-    }
-
-    // read the header
-    {
-        strncpy(ctx->header.magic, magic, 4);
-
-        ctx->kv    = NULL;
-        ctx->infos = NULL;
-        ctx->data  = NULL;
-
-        ok = ok && gguf_fread_el(file, &ctx->header.version,   sizeof(ctx->header.version),   &offset);
-        ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset);
-        ok = ok && gguf_fread_el(file, &ctx->header.n_kv,      sizeof(ctx->header.n_kv),      &offset);
-
-        if (ctx->header.version == 1) {
-            fprintf(stderr, "%s: GGUFv1 is no longer supported. please use a more up-to-date version\n", __func__);
-            gguf_free(ctx);
-            return NULL;
-        }
-
-        // sanity-checks to prevent from integer/buffer overflows
-
-        ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/sizeof(struct gguf_tensor_info));
-        ok = ok && (ctx->header.n_tensors < (SIZE_MAX/2)/ggml_tensor_overhead());
-        ok = ok && (ctx->header.n_kv      < (SIZE_MAX/2)/sizeof(struct gguf_kv));
-
-        if (!ok) {
-            fprintf(stderr, "%s: failed to read header\n", __func__);
-            gguf_free(ctx);
-            return NULL;
-        }
-    }
-
-    // read the kv pairs
-    {
-        const uint64_t n_kv = ctx->header.n_kv;
-
-        if (n_kv > 0) {
-            ctx->kv = calloc(n_kv, sizeof(struct gguf_kv));
-            if (!ctx->kv) {
-                fprintf(stderr, "%s: failed to allocate memory for kv pairs\n", __func__);
-                gguf_free(ctx);
-                return NULL;
-            }
-        }
-
-        for (uint64_t i = 0; i < n_kv; ++i) {
-            struct gguf_kv * kv = &ctx->kv[i];
-
-            //fprintf(stderr, "%s: reading kv %d\n", __func__, i);
-
-            ok = ok && gguf_fread_str(file, &kv->key,                    &offset);
-            ok = ok && gguf_fread_el (file, &kv->type, sizeof(kv->type), &offset);
-
-            //fprintf(stderr, "%s: reading kv with key %s\n", __func__, kv->key.data);
-
-            switch (kv->type) {
-                case GGUF_TYPE_UINT8:   ok = ok && gguf_fread_el (file, &kv->value.uint8,   sizeof(kv->value.uint8),   &offset); break;
-                case GGUF_TYPE_INT8:    ok = ok && gguf_fread_el (file, &kv->value.int8,    sizeof(kv->value.int8),    &offset); break;
-                case GGUF_TYPE_UINT16:  ok = ok && gguf_fread_el (file, &kv->value.uint16,  sizeof(kv->value.uint16),  &offset); break;
-                case GGUF_TYPE_INT16:   ok = ok && gguf_fread_el (file, &kv->value.int16,   sizeof(kv->value.int16),   &offset); break;
-                case GGUF_TYPE_UINT32:  ok = ok && gguf_fread_el (file, &kv->value.uint32,  sizeof(kv->value.uint32),  &offset); break;
-                case GGUF_TYPE_INT32:   ok = ok && gguf_fread_el (file, &kv->value.int32,   sizeof(kv->value.int32),   &offset); break;
-                case GGUF_TYPE_FLOAT32: ok = ok && gguf_fread_el (file, &kv->value.float32, sizeof(kv->value.float32), &offset); break;
-                case GGUF_TYPE_UINT64:  ok = ok && gguf_fread_el (file, &kv->value.uint64,  sizeof(kv->value.uint64),  &offset); break;
-                case GGUF_TYPE_INT64:   ok = ok && gguf_fread_el (file, &kv->value.int64,   sizeof(kv->value.int64),   &offset); break;
-                case GGUF_TYPE_FLOAT64: ok = ok && gguf_fread_el (file, &kv->value.float64, sizeof(kv->value.float64), &offset); break;
-                case GGUF_TYPE_BOOL:    ok = ok && gguf_fread_el (file, &kv->value.bool_,   sizeof(kv->value.bool_),   &offset); break;
-                case GGUF_TYPE_STRING:  ok = ok && gguf_fread_str(file, &kv->value.str,                                &offset); break;
-                case GGUF_TYPE_ARRAY:
-                    {
-                        ok = ok && gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset);
-                        ok = ok && gguf_fread_el(file, &kv->value.arr.n,    sizeof(kv->value.arr.n),    &offset);
-
-                        switch (kv->value.arr.type) {
-                            case GGUF_TYPE_UINT8:
-                            case GGUF_TYPE_INT8:
-                            case GGUF_TYPE_UINT16:
-                            case GGUF_TYPE_INT16:
-                            case GGUF_TYPE_UINT32:
-                            case GGUF_TYPE_INT32:
-                            case GGUF_TYPE_FLOAT32:
-                            case GGUF_TYPE_UINT64:
-                            case GGUF_TYPE_INT64:
-                            case GGUF_TYPE_FLOAT64:
-                            case GGUF_TYPE_BOOL:
-                                {
-                                    // prevent from integer overflow in the malloc below
-                                    if (kv->value.arr.n >= SIZE_MAX/gguf_type_size(kv->value.arr.type)) {
-                                        fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n);
-                                        gguf_free(ctx);
-                                        return NULL;
-                                    }
-
-                                    kv->value.arr.data = calloc(kv->value.arr.n, gguf_type_size(kv->value.arr.type));
-                                    if (!kv->value.arr.data) {
-                                        fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
-                                        gguf_free(ctx);
-                                        return NULL;
-                                    }
-
-                                    ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type), &offset);
-                                } break;
-                            case GGUF_TYPE_STRING:
-                                {
-                                    // prevent from integer overflow in the malloc below
-                                    if (kv->value.arr.n >= SIZE_MAX/sizeof(struct gguf_str)) {
-                                        fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n);
-                                        gguf_free(ctx);
-                                        return NULL;
-                                    }
-
-                                    kv->value.arr.data = calloc(kv->value.arr.n, sizeof(struct gguf_str));
-                                    if (!kv->value.arr.data) {
-                                        fprintf(stderr, "%s: failed to allocate memory for array\n", __func__);
-                                        gguf_free(ctx);
-                                        return NULL;
-                                    }
-
-                                    for (uint64_t j = 0; j < kv->value.arr.n; ++j) {
-                                        ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset);
-                                    }
-                                } break;
-                            case GGUF_TYPE_ARRAY:
-                            default:
-                                {
-                                    fprintf(stderr, "%s: invalid array type %d\n", __func__, kv->value.arr.type);
-                                    ok = false;
-                                } break;
-                        }
-                    } break;
-                default:
-                    {
-                        fprintf(stderr, "%s: invalid type %d\n", __func__, kv->type);
-                        ok = false;
-                    } break;
-            }
-
-            if (!ok) {
-                break;
-            }
-        }
-
-        if (!ok) {
-            fprintf(stderr, "%s: failed to read key-value pairs\n", __func__);
-            gguf_free(ctx);
-            return NULL;
-        }
-    }
-
-    // read the tensor infos
-    if (ctx->header.n_tensors > 0) {
-        ctx->infos = calloc(ctx->header.n_tensors, sizeof(struct gguf_tensor_info));
-        if (!ctx->infos) {
-            fprintf(stderr, "%s: failed to allocate memory for tensor infos\n", __func__);
-            gguf_free(ctx);
-            return NULL;
-        }
-
-        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
-            struct gguf_tensor_info * info = &ctx->infos[i];
-
-            for (int j = 0; j < GGML_MAX_DIMS; ++j) {
-                info->ne[j] = 1;
-            }
-
-            ok = ok && gguf_fread_str(file, &info->name,                          &offset);
-            ok = ok && gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims),  &offset);
-
-            ok = ok && (info->n_dims <= GGML_MAX_DIMS);
-
-            for (uint32_t j = 0; j < info->n_dims; ++j) {
-                ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset);
-            }
-
-            ok = ok && gguf_fread_el (file, &info->type,   sizeof(info->type),    &offset);
-            ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset),  &offset);
-
-            ok = ok && gguf_tensor_info_sanitize(info);
-
-            // make sure there is no duplicated tensor names
-            for (uint64_t j = 0; j < i && ok; ++j) {
-                if (strcmp(info->name.data, ctx->infos[j].name.data) == 0) {
-                    fprintf(stderr, "%s: duplicated tensor name %s\n", __func__, info->name.data);
-                    ok = false;
-                }
-            }
-
-            if (!ok) {
-                fprintf(stderr, "%s: failed to read tensor info\n", __func__);
-                gguf_free(ctx);
-                return NULL;
-            }
-        }
-    }
-
-    ctx->alignment = GGUF_DEFAULT_ALIGNMENT;
-
-    int alignment_idx = gguf_find_key(ctx, "general.alignment");
-    if (alignment_idx != -1) {
-        ctx->alignment = gguf_get_val_u32(ctx, alignment_idx);
-    }
-
-    // we require the data section to be aligned, so take into account any padding
-    {
-        const size_t offset_pad = offset % ctx->alignment;
-
-        if (offset_pad != 0) {
-            offset += ctx->alignment - offset_pad;
-            fseek(file, offset, SEEK_SET);
-        }
-    }
-
-    // store the current file offset - this is where the data section starts
-    ctx->offset = offset;
-
-    // compute the total size of the data section, taking into account the alignment
-    {
-        ctx->size = 0;
-        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
-            struct gguf_tensor_info * info = &ctx->infos[i];
-
-            const int64_t ne =
-                (int64_t) info->ne[0] *
-                (int64_t) info->ne[1] *
-                (int64_t) info->ne[2] *
-                (int64_t) info->ne[3];
-
-            if (ggml_blck_size(info->type) == 0 ) {
-                // this tensor type support have been removed:
-                fprintf(stderr, "%s: tensor '%s' of type %d: %s\n",
-                        __func__, info->name.data, (int) info->type, ggml_type_name(info->type));
-                gguf_free(ctx);
-                return NULL;
-            }
-
-            if (ne % ggml_blck_size(info->type) != 0) {
-                fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n",
-                        __func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type));
-                gguf_free(ctx);
-                return NULL;
-            }
-
-            const size_t size_cur = ggml_row_size(info->type, ne);
-
-            ctx->size += GGML_PAD(size_cur, ctx->alignment);
-        }
-    }
-
-    // load the tensor data only if requested
-    if (params.ctx != NULL) {
-        // if the provided gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob
-        // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of
-        // the ggml_tensor structs to the appropriate locations in the binary blob
-
-        // compute the exact size needed for the new ggml_context
-        const size_t mem_size =
-            params.no_alloc ?
-            (ctx->header.n_tensors    )*ggml_tensor_overhead() :
-            (ctx->header.n_tensors + 1)*ggml_tensor_overhead() + ctx->size;
-
-        struct ggml_init_params pdata = {
-            .mem_size   = mem_size,
-            .mem_buffer = NULL,
-            .no_alloc   = params.no_alloc,
-        };
-
-        *params.ctx = ggml_init(pdata);
-        if (*params.ctx == NULL) {
-            fprintf(stderr, "%s: failed to initialize context\n", __func__);
-            gguf_free(ctx);
-            return NULL;
-        }
-
-        struct ggml_context * ctx_data = *params.ctx;
-
-        struct ggml_tensor * data = NULL;
-
-        if (!params.no_alloc) {
-            data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size);
-
-            ok = ok && data != NULL;
-
-            // read the binary blob with the tensor data
-            ok = ok && gguf_fread_el(file, data->data, ctx->size, &offset);
-
-            if (!ok) {
-                fprintf(stderr, "%s: failed to read tensor data\n", __func__);
-                ggml_free(ctx_data);
-                gguf_free(ctx);
-                return NULL;
-            }
-
-            ctx->data = data->data;
-        }
-
-        ggml_set_no_alloc(ctx_data, true);
-
-        // create the tensors
-        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
-            const int64_t ne[GGML_MAX_DIMS] = {
-                ctx->infos[i].ne[0],
-                ctx->infos[i].ne[1],
-                ctx->infos[i].ne[2],
-                ctx->infos[i].ne[3],
-            };
-
-            struct ggml_tensor * cur = ggml_new_tensor(ctx_data, ctx->infos[i].type, ctx->infos[i].n_dims, ne);
-
-            ok = ok && cur != NULL;
-
-            if (!ok) {
-                break;
-            }
-
-            ggml_set_name(cur, ctx->infos[i].name.data);
-
-            // point the data member to the appropriate location in the binary blob using the tensor infos
-            if (!params.no_alloc) {
-              //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file
-                cur->data = (char *) data->data + ctx->infos[i].offset;               // offset from data
-            }
-        }
-
-        if (!ok) {
-            fprintf(stderr, "%s: failed to read the tensor data\n", __func__);
-            ggml_free(ctx_data);
-            gguf_free(ctx);
-            return NULL;
-        }
-
-        ggml_set_no_alloc(ctx_data, params.no_alloc);
-    }
-
-    return ctx;
-}
-
-struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
-    FILE * file = ggml_fopen(fname, "rb");
-    if (!file) {
-        fprintf(stderr, "%s: failed to open '%s': '%s'\n", __func__, fname, strerror(errno));
-        return NULL;
-    }
-
-    struct gguf_context * result = gguf_init_from_file_impl(file, params);
-    fclose(file);
-    return result;
-}
-
-void gguf_free(struct gguf_context * ctx) {
-    if (ctx == NULL) {
-        return;
-    }
-
-    if (ctx->kv) {
-        // free string memory - not great..
-        for (uint64_t i = 0; i < ctx->header.n_kv; ++i) {
-            gguf_free_kv(&ctx->kv[i]);
-        }
-
-        GGML_FREE(ctx->kv);
-    }
-
-    if (ctx->infos) {
-        for (uint64_t i = 0; i < ctx->header.n_tensors; ++i) {
-            struct gguf_tensor_info * info = &ctx->infos[i];
-
-            if (info->name.data) {
-                GGML_FREE(info->name.data);
-            }
-        }
-
-        GGML_FREE(ctx->infos);
-    }
-
-    GGML_FREE(ctx);
-}
-
-const char * gguf_type_name(enum gguf_type type) {
-    return GGUF_TYPE_NAME[type];
-}
-
-int gguf_get_version(const struct gguf_context * ctx) {
-    return ctx->header.version;
-}
-
-size_t gguf_get_alignment(const struct gguf_context * ctx) {
-    return ctx->alignment;
-}
-
-size_t gguf_get_data_offset(const struct gguf_context * ctx) {
-    return ctx->offset;
-}
-
-void * gguf_get_data(const struct gguf_context * ctx) {
-    return ctx->data;
-}
-
-int gguf_get_n_kv(const struct gguf_context * ctx) {
-    return ctx->header.n_kv;
-}
-
-int gguf_find_key(const struct gguf_context * ctx, const char * key) {
-    // return -1 if key not found
-    int keyfound = -1;
-
-    const int n_kv = gguf_get_n_kv(ctx);
-
-    for (int i = 0; i < n_kv; ++i) {
-        if (strcmp(key, gguf_get_key(ctx, i)) == 0) {
-            keyfound = i;
-            break;
-        }
-    }
-
-    return keyfound;
-}
-
-const char * gguf_get_key(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    return ctx->kv[key_id].key.data;
-}
-
-enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    return ctx->kv[key_id].type;
-}
-
-enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
-    return ctx->kv[key_id].value.arr.type;
-}
-
-const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
-    return ctx->kv[key_id].value.arr.data;
-}
-
-const char * gguf_get_arr_str(const struct gguf_context * ctx, int key_id, int i) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
-    struct gguf_kv * kv = &ctx->kv[key_id];
-    struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i];
-    return str->data;
-}
-
-int gguf_get_arr_n(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_ARRAY);
-    return ctx->kv[key_id].value.arr.n;
-}
-
-uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT8);
-    return ctx->kv[key_id].value.uint8;
-}
-
-int8_t gguf_get_val_i8(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT8);
-    return ctx->kv[key_id].value.int8;
-}
-
-uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT16);
-    return ctx->kv[key_id].value.uint16;
-}
-
-int16_t gguf_get_val_i16(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT16);
-    return ctx->kv[key_id].value.int16;
-}
-
-uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT32);
-    return ctx->kv[key_id].value.uint32;
-}
-
-int32_t gguf_get_val_i32(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT32);
-    return ctx->kv[key_id].value.int32;
-}
-
-float gguf_get_val_f32(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT32);
-    return ctx->kv[key_id].value.float32;
-}
-
-uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_UINT64);
-    return ctx->kv[key_id].value.uint64;
-}
-
-int64_t gguf_get_val_i64(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_INT64);
-    return ctx->kv[key_id].value.int64;
-}
-
-double gguf_get_val_f64(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_FLOAT64);
-    return ctx->kv[key_id].value.float64;
-}
-
-bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_BOOL);
-    return ctx->kv[key_id].value.bool_;
-}
-
-const char * gguf_get_val_str(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type == GGUF_TYPE_STRING);
-    return ctx->kv[key_id].value.str.data;
-}
-
-const void * gguf_get_val_data(const struct gguf_context * ctx, int key_id) {
-    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
-    GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_ARRAY);
-    GGML_ASSERT(ctx->kv[key_id].type != GGUF_TYPE_STRING);
-    return &ctx->kv[key_id].value;
-}
-
-int gguf_get_n_tensors(const struct gguf_context * ctx) {
-    return ctx->header.n_tensors;
-}
-
-int gguf_find_tensor(const struct gguf_context * ctx, const char * name) {
-    // return -1 if tensor not found
-    int tensorfound = -1;
-
-    const int n_tensors = gguf_get_n_tensors(ctx);
-
-    for (int i = 0; i < n_tensors; ++i) {
-        if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) {
-            tensorfound = i;
-            break;
-        }
-    }
-
-    return tensorfound;
-}
-
-size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i) {
-    return ctx->infos[i].offset;
-}
-
-char * gguf_get_tensor_name(const struct gguf_context * ctx, int i) {
-    return ctx->infos[i].name.data;
-}
-
-enum ggml_type gguf_get_tensor_type(const struct gguf_context * ctx, int i) {
-    return ctx->infos[i].type;
-}
-
-// returns the index
-static int gguf_get_or_add_key(struct gguf_context * ctx, const char * key) {
-    const int idx = gguf_find_key(ctx, key);
-    if (idx >= 0) {
-        return idx;
-    }
-
-    const int n_kv = gguf_get_n_kv(ctx);
-
-    ctx->kv = realloc(ctx->kv, (n_kv + 1) * sizeof(struct gguf_kv));
-    ctx->kv[n_kv].key.n    = strlen(key);
-    ctx->kv[n_kv].key.data = strdup(key);
-    ctx->header.n_kv++;
-
-    return n_kv;
-}
-
-void gguf_remove_key(struct gguf_context * ctx, const char * key) {
-    const int idx = gguf_find_key(ctx, key);
-    if (idx >= 0) {
-        const int n_kv = gguf_get_n_kv(ctx);
-        gguf_free_kv(&ctx->kv[idx]);
-        for (int i = idx; i < n_kv-1; ++i) {
-            ctx->kv[i] = ctx->kv[i+1];
-        }
-        ctx->kv = realloc(ctx->kv, (n_kv - 1) * sizeof(struct gguf_kv));
-        ctx->header.n_kv--;
-    }
-}
-
-void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type        = GGUF_TYPE_UINT8;
-    ctx->kv[idx].value.uint8 = val;
-}
-
-void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type       = GGUF_TYPE_INT8;
-    ctx->kv[idx].value.int8 = val;
-}
-
-void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type         = GGUF_TYPE_UINT16;
-    ctx->kv[idx].value.uint16 = val;
-}
-
-void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type        = GGUF_TYPE_INT16;
-    ctx->kv[idx].value.int16 = val;
-}
-
-void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type         = GGUF_TYPE_UINT32;
-    ctx->kv[idx].value.uint32 = val;
-}
-
-void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type        = GGUF_TYPE_INT32;
-    ctx->kv[idx].value.int32 = val;
-}
-
-void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type          = GGUF_TYPE_FLOAT32;
-    ctx->kv[idx].value.float32 = val;
-}
-
-void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type         = GGUF_TYPE_UINT64;
-    ctx->kv[idx].value.uint64 = val;
-}
-
-void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type        = GGUF_TYPE_INT64;
-    ctx->kv[idx].value.int64 = val;
-}
-
-void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type          = GGUF_TYPE_FLOAT64;
-    ctx->kv[idx].value.float64 = val;
-}
-
-void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type        = GGUF_TYPE_BOOL;
-    ctx->kv[idx].value.bool_ = val;
-}
-
-void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type           = GGUF_TYPE_STRING;
-    ctx->kv[idx].value.str.n    = strlen(val);
-    ctx->kv[idx].value.str.data = strdup(val);
-}
-
-void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type           = GGUF_TYPE_ARRAY;
-    ctx->kv[idx].value.arr.type = type;
-    ctx->kv[idx].value.arr.n    = n;
-    ctx->kv[idx].value.arr.data = GGML_CALLOC(n, gguf_type_size(type));
-    memcpy(ctx->kv[idx].value.arr.data, data, n*gguf_type_size(type));
-}
-
-void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, int n) {
-    const int idx = gguf_get_or_add_key(ctx, key);
-
-    ctx->kv[idx].type           = GGUF_TYPE_ARRAY;
-    ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING;
-    ctx->kv[idx].value.arr.n    = n;
-    ctx->kv[idx].value.arr.data = GGML_CALLOC(n, sizeof(struct gguf_str));
-    for (int i = 0; i < n; i++) {
-        struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i];
-        str->n    = strlen(data[i]);
-        str->data = strdup(data[i]);
-    }
-}
-
-// set or add KV pairs from another context
-void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) {
-    for (uint32_t i = 0; i < src->header.n_kv; i++) {
-        switch (src->kv[i].type) {
-            case GGUF_TYPE_UINT8:   gguf_set_val_u8  (ctx, src->kv[i].key.data, src->kv[i].value.uint8);    break;
-            case GGUF_TYPE_INT8:    gguf_set_val_i8  (ctx, src->kv[i].key.data, src->kv[i].value.int8);     break;
-            case GGUF_TYPE_UINT16:  gguf_set_val_u16 (ctx, src->kv[i].key.data, src->kv[i].value.uint16);   break;
-            case GGUF_TYPE_INT16:   gguf_set_val_i16 (ctx, src->kv[i].key.data, src->kv[i].value.int16);    break;
-            case GGUF_TYPE_UINT32:  gguf_set_val_u32 (ctx, src->kv[i].key.data, src->kv[i].value.uint32);   break;
-            case GGUF_TYPE_INT32:   gguf_set_val_i32 (ctx, src->kv[i].key.data, src->kv[i].value.int32);    break;
-            case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, src->kv[i].key.data, src->kv[i].value.float32);  break;
-            case GGUF_TYPE_UINT64:  gguf_set_val_u64 (ctx, src->kv[i].key.data, src->kv[i].value.uint64);   break;
-            case GGUF_TYPE_INT64:   gguf_set_val_i64 (ctx, src->kv[i].key.data, src->kv[i].value.int64);    break;
-            case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, src->kv[i].key.data, src->kv[i].value.float64);  break;
-            case GGUF_TYPE_BOOL:    gguf_set_val_bool(ctx, src->kv[i].key.data, src->kv[i].value.bool_);    break;
-            case GGUF_TYPE_STRING:  gguf_set_val_str (ctx, src->kv[i].key.data, src->kv[i].value.str.data); break;
-            case GGUF_TYPE_ARRAY:
-                {
-                    if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) {
-                        const char ** data = GGML_CALLOC(src->kv[i].value.arr.n, sizeof(char *));
-                        for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) {
-                            data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data;
-                        }
-                        gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n);
-                        GGML_FREE((void *)data);
-                    } else if (src->kv[i].value.arr.type == GGUF_TYPE_ARRAY) {
-                        GGML_ABORT("nested arrays not supported");
-                    } else {
-                        gguf_set_arr_data(ctx, src->kv[i].key.data, src->kv[i].value.arr.type, src->kv[i].value.arr.data, src->kv[i].value.arr.n);
-                    }
-                } break;
-            default: GGML_ABORT("invalid type");
-        }
-    }
-}
-
-void gguf_add_tensor(
-             struct gguf_context * ctx,
-        const struct ggml_tensor * tensor) {
-    GGML_ASSERT(tensor);
-    if (gguf_find_tensor(ctx, tensor->name) != -1) {
-        GGML_ABORT("duplicated tensor name");
-    }
-
-    const int idx = ctx->header.n_tensors;
-    ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info));
-
-    ctx->infos[idx].name.n    = strlen(tensor->name);
-    ctx->infos[idx].name.data = strdup(tensor->name);
-
-    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
-        ctx->infos[idx].ne[i] = 1;
-    }
-
-    ctx->infos[idx].n_dims = ggml_n_dims(tensor);
-    for (uint32_t i = 0; i < ctx->infos[idx].n_dims; i++) {
-        ctx->infos[idx].ne[i] = tensor->ne[i];
-    }
-
-    ctx->infos[idx].type   = tensor->type;
-    ctx->infos[idx].offset = 0;
-    ctx->infos[idx].data   = tensor->data;
-    ctx->infos[idx].size   = ggml_nbytes(tensor);
-
-    if (ctx->header.n_tensors > 0) {
-        ctx->infos[idx].offset = ctx->infos[idx - 1].offset + GGML_PAD(ctx->infos[idx - 1].size, ctx->alignment);
-    }
-
-    ctx->header.n_tensors++;
-}
-
-void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) {
-    const int idx = gguf_find_tensor(ctx, name);
-    if (idx < 0) {
-        GGML_ABORT("tensor not found");
-    }
-
-    ctx->infos[idx].type = type;
-}
-
-void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size) {
-    const int idx = gguf_find_tensor(ctx, name);
-    if (idx < 0) {
-        GGML_ABORT("tensor not found");
-    }
-
-    ctx->infos[idx].data = data;
-    ctx->infos[idx].size = size;
-
-    // update offsets
-    for (uint32_t i = idx + 1; i < ctx->header.n_tensors; ++i) {
-        ctx->infos[i].offset = ctx->infos[i - 1].offset + GGML_PAD(ctx->infos[i - 1].size, ctx->alignment);
-    }
-}
-
-//static void gguf_fwrite_str(FILE * file, const struct gguf_str * val) {
-//    fwrite(&val->n,   sizeof(val->n),    1, file);
-//    fwrite(val->data, sizeof(char), val->n, file);
-//}
-//
-//static void gguf_fwrite_el(FILE * file, const void * val, size_t size) {
-//    fwrite(val, sizeof(char), size, file);
-//}
-
-struct gguf_buf gguf_buf_init(size_t size) {
-    struct gguf_buf buf = {
-        /*buf.data   =*/ size == 0 ? NULL : GGML_CALLOC(1, size),
-        /*buf.size   =*/ size,
-        /*buf.offset =*/ 0,
-    };
-
-    return buf;
-}
-
-void gguf_buf_free(struct gguf_buf buf) {
-    if (buf.data) {
-        GGML_FREE(buf.data);
-    }
-}
-
-static void gguf_buf_grow(struct gguf_buf * buf, size_t size) {
-    if (buf->offset + size > buf->size) {
-        buf->size = 1.5*(buf->offset + size);
-        if (buf->data) {
-            buf->data = realloc(buf->data, buf->size);
-        }
-    }
-}
-
-static void gguf_bwrite_str(struct gguf_buf * buf, const struct gguf_str * val) {
-    gguf_buf_grow(buf, sizeof(val->n) + val->n);
-
-    if (buf->data) {
-        memcpy((char *) buf->data + buf->offset, &val->n, sizeof(val->n));
-    }
-    buf->offset += sizeof(val->n);
-
-    if (buf->data) {
-        memcpy((char *) buf->data + buf->offset, val->data, val->n);
-    }
-    buf->offset += val->n;
-}
-
-static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_size) {
-    gguf_buf_grow(buf, el_size);
-
-    if (buf->data) {
-        memcpy((char *) buf->data + buf->offset, val, el_size);
-    }
-    buf->offset += el_size;
-}
-
-void gguf_write_to_buf(const struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) {
-    // write header
-    gguf_bwrite_el(buf, &ctx->header.magic,     sizeof(ctx->header.magic));
-    gguf_bwrite_el(buf, &ctx->header.version,   sizeof(ctx->header.version));
-    gguf_bwrite_el(buf, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors));
-    gguf_bwrite_el(buf, &ctx->header.n_kv,      sizeof(ctx->header.n_kv));
-
-    // write key-value pairs
-    for (uint32_t i = 0; i < ctx->header.n_kv; ++i) {
-        struct gguf_kv * kv = &ctx->kv[i];
-
-        gguf_bwrite_str(buf, &kv->key);
-        gguf_bwrite_el (buf, &kv->type, sizeof(kv->type));
-
-        switch (kv->type) {
-            case GGUF_TYPE_UINT8:   gguf_bwrite_el( buf, &kv->value.uint8,   sizeof(kv->value.uint8)  ); break;
-            case GGUF_TYPE_INT8:    gguf_bwrite_el (buf, &kv->value.int8,    sizeof(kv->value.int8)   ); break;
-            case GGUF_TYPE_UINT16:  gguf_bwrite_el (buf, &kv->value.uint16,  sizeof(kv->value.uint16) ); break;
-            case GGUF_TYPE_INT16:   gguf_bwrite_el (buf, &kv->value.int16,   sizeof(kv->value.int16)  ); break;
-            case GGUF_TYPE_UINT32:  gguf_bwrite_el (buf, &kv->value.uint32,  sizeof(kv->value.uint32) ); break;
-            case GGUF_TYPE_INT32:   gguf_bwrite_el (buf, &kv->value.int32,   sizeof(kv->value.int32)  ); break;
-            case GGUF_TYPE_FLOAT32: gguf_bwrite_el (buf, &kv->value.float32, sizeof(kv->value.float32)); break;
-            case GGUF_TYPE_UINT64:  gguf_bwrite_el (buf, &kv->value.uint64,  sizeof(kv->value.uint64) ); break;
-            case GGUF_TYPE_INT64:   gguf_bwrite_el (buf, &kv->value.int64,   sizeof(kv->value.int64)  ); break;
-            case GGUF_TYPE_FLOAT64: gguf_bwrite_el (buf, &kv->value.float64, sizeof(kv->value.float64)); break;
-            case GGUF_TYPE_BOOL:    gguf_bwrite_el (buf, &kv->value.bool_,   sizeof(kv->value.bool_)  ); break;
-            case GGUF_TYPE_STRING:  gguf_bwrite_str(buf, &kv->value.str                               ); break;
-            case GGUF_TYPE_ARRAY:
-                {
-                    gguf_bwrite_el(buf, &kv->value.arr.type, sizeof(kv->value.arr.type));
-                    gguf_bwrite_el(buf, &kv->value.arr.n,    sizeof(kv->value.arr.n)   );
-
-                    switch (kv->value.arr.type) {
-                        case GGUF_TYPE_UINT8:
-                        case GGUF_TYPE_INT8:
-                        case GGUF_TYPE_UINT16:
-                        case GGUF_TYPE_INT16:
-                        case GGUF_TYPE_UINT32:
-                        case GGUF_TYPE_INT32:
-                        case GGUF_TYPE_FLOAT32:
-                        case GGUF_TYPE_UINT64:
-                        case GGUF_TYPE_INT64:
-                        case GGUF_TYPE_FLOAT64:
-                        case GGUF_TYPE_BOOL:
-                            {
-                                gguf_bwrite_el(buf, kv->value.arr.data, kv->value.arr.n * gguf_type_size(kv->value.arr.type));
-                            } break;
-                        case GGUF_TYPE_STRING:
-                            {
-                                for (uint32_t j = 0; j < kv->value.arr.n; ++j) {
-                                    gguf_bwrite_str(buf, &((struct gguf_str *) kv->value.arr.data)[j]);
-                                }
-                            } break;
-                        case GGUF_TYPE_ARRAY:
-                        default: GGML_ABORT("invalid type");
-                    }
-                } break;
-            default: GGML_ABORT("invalid type");
-        }
-    }
-
-    // write tensor infos
-    for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
-        struct gguf_tensor_info * info = &ctx->infos[i];
-
-        gguf_bwrite_str(buf, &info->name);
-        gguf_bwrite_el (buf, &info->n_dims, sizeof(info->n_dims));
-        for (uint32_t j = 0; j < info->n_dims; ++j) {
-            gguf_bwrite_el(buf, &info->ne[j], sizeof(info->ne[j]));
-        }
-        gguf_bwrite_el(buf, &info->type,   sizeof(info->type));
-        gguf_bwrite_el(buf, &info->offset, sizeof(info->offset));
-    }
-
-    // we require the data section to be aligned, so take into account any padding
-    {
-        const size_t offset     = buf->offset;
-        const size_t offset_pad = GGML_PAD(offset, ctx->alignment);
-
-        if (offset_pad != offset) {
-            uint8_t pad = 0;
-            for (size_t i = 0; i < offset_pad - offset; ++i) {
-                gguf_bwrite_el(buf, &pad, sizeof(pad));
-            }
-        }
-    }
-
-    if (only_meta) {
-        return;
-    }
-
-    size_t offset = 0;
-
-    // write tensor data
-    for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) {
-        struct gguf_tensor_info * info = &ctx->infos[i];
-
-        const size_t size     = info->size;
-        const size_t size_pad = GGML_PAD(size, ctx->alignment);
-
-        gguf_bwrite_el(buf, info->data, size);
-
-        if (size_pad != size) {
-            uint8_t pad = 0;
-            for (size_t j = 0; j < size_pad - size; ++j) {
-                gguf_bwrite_el(buf, &pad, sizeof(pad));
-            }
-        }
-
-        GGML_ASSERT(offset == info->offset);
-
-        offset += size_pad;
-    }
-}
-
-void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
-    FILE * file = ggml_fopen(fname, "wb");
-    if (!file) {
-        GGML_ABORT("failed to open file for writing");
-    }
-
-    struct gguf_buf buf = gguf_buf_init(16*1024);
-
-    gguf_write_to_buf(ctx, &buf, only_meta);
-
-    fwrite(buf.data, 1, buf.offset, file);
-
-    gguf_buf_free(buf);
-
-    fclose(file);
-}
-
-size_t gguf_get_meta_size(const struct gguf_context * ctx) {
-    // no allocs - only compute size
-    struct gguf_buf buf = gguf_buf_init(0);
-
-    gguf_write_to_buf(ctx, &buf, true);
-
-    return buf.offset;
-}
-
-void gguf_get_meta_data(const struct gguf_context * ctx, void * data) {
-    struct gguf_buf buf = gguf_buf_init(16*1024);
-
-    gguf_write_to_buf(ctx, &buf, true);
-
-    memcpy(data, buf.data, buf.offset);
-
-    gguf_buf_free(buf);
-}
-
 void ggml_log_set(ggml_log_callback log_callback, void * user_data) {
     g_logger_state.log_callback = log_callback ? log_callback : ggml_log_callback_default;
     g_logger_state.log_callback_user_data = user_data;
diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp
new file mode 100644
index 00000000000..655ed600a17
--- /dev/null
+++ b/ggml/src/gguf.cpp
@@ -0,0 +1,1325 @@
+#include "ggml.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+#include "gguf.h"
+
+#include <cinttypes>
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <map>
+#include <new>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+template <typename T>
+struct type_to_gguf_type;
+
+template <>
+struct type_to_gguf_type<uint8_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_UINT8;
+};
+
+template <>
+struct type_to_gguf_type<int8_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_INT8;
+};
+
+template <>
+struct type_to_gguf_type<uint16_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_UINT16;
+};
+
+template <>
+struct type_to_gguf_type<int16_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_INT16;
+};
+
+template <>
+struct type_to_gguf_type<uint32_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_UINT32;
+};
+
+template <>
+struct type_to_gguf_type<int32_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_INT32;
+};
+
+template <>
+struct type_to_gguf_type<float> {
+    static constexpr enum gguf_type value = GGUF_TYPE_FLOAT32;
+};
+
+template <>
+struct type_to_gguf_type<bool> {
+    static constexpr enum gguf_type value = GGUF_TYPE_BOOL;
+};
+
+template <>
+struct type_to_gguf_type<std::string> {
+    static constexpr enum gguf_type value = GGUF_TYPE_STRING;
+};
+
+template <>
+struct type_to_gguf_type<uint64_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_UINT64;
+};
+
+template <>
+struct type_to_gguf_type<int64_t> {
+    static constexpr enum gguf_type value = GGUF_TYPE_INT64;
+};
+
+template <>
+struct type_to_gguf_type<double> {
+    static constexpr enum gguf_type value = GGUF_TYPE_FLOAT64;
+};
+
+static const std::map<gguf_type, size_t> GGUF_TYPE_SIZE = {
+    {GGUF_TYPE_UINT8,   sizeof(uint8_t)},
+    {GGUF_TYPE_INT8,    sizeof(int8_t)},
+    {GGUF_TYPE_UINT16,  sizeof(uint16_t)},
+    {GGUF_TYPE_INT16,   sizeof(int16_t)},
+    {GGUF_TYPE_UINT32,  sizeof(uint32_t)},
+    {GGUF_TYPE_INT32,   sizeof(int32_t)},
+    {GGUF_TYPE_FLOAT32, sizeof(float)},
+    {GGUF_TYPE_BOOL,    sizeof(int8_t)},
+    {GGUF_TYPE_STRING,  0}, // undefined
+    {GGUF_TYPE_ARRAY,   0}, // undefined
+    {GGUF_TYPE_UINT64,  sizeof(uint64_t)},
+    {GGUF_TYPE_INT64,   sizeof(int64_t)},
+    {GGUF_TYPE_FLOAT64, sizeof(double)},
+};
+static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
+
+static const std::map<gguf_type, const char *> GGUF_TYPE_NAME = {
+    {GGUF_TYPE_UINT8,   "u8"},
+    {GGUF_TYPE_INT8,    "i8"},
+    {GGUF_TYPE_UINT16,  "u16"},
+    {GGUF_TYPE_INT16,   "i16"},
+    {GGUF_TYPE_UINT32,  "u32"},
+    {GGUF_TYPE_INT32,   "i32"},
+    {GGUF_TYPE_FLOAT32, "f32"},
+    {GGUF_TYPE_BOOL,    "bool"},
+    {GGUF_TYPE_STRING,  "str"},
+    {GGUF_TYPE_ARRAY,   "arr"},
+    {GGUF_TYPE_UINT64,  "u64"},
+    {GGUF_TYPE_INT64,   "i64"},
+    {GGUF_TYPE_FLOAT64, "f64"},
+};
+static_assert(GGUF_TYPE_COUNT == 13, "GGUF_TYPE_COUNT != 13");
+
+size_t gguf_type_size(enum gguf_type type) {
+    auto it = GGUF_TYPE_SIZE.find(type);
+    return it == GGUF_TYPE_SIZE.end() ? 0 : it->second;
+}
+
+struct gguf_kv {
+    std::string key;
+
+    bool is_array;
+    enum gguf_type type;
+
+    std::vector<int8_t>      data;
+    std::vector<std::string> data_string;
+
+    template <typename T>
+    gguf_kv(const std::string & key, const T value)
+            : key(key), is_array(false), type(type_to_gguf_type<T>::value) {
+        GGML_ASSERT(!key.empty());
+        data.resize(sizeof(T));
+        memcpy(data.data(), &value, sizeof(T));
+    }
+
+    template <typename T>
+    gguf_kv(const std::string & key, const std::vector<T> & value)
+            : key(key), is_array(true), type(type_to_gguf_type<T>::value) {
+        GGML_ASSERT(!key.empty());
+        data.resize(value.size()*sizeof(T));
+        for (size_t i = 0; i < value.size(); ++i) {
+            const T tmp = value[i];
+            memcpy(data.data() + i*sizeof(T), &tmp, sizeof(T));
+        }
+    }
+
+    gguf_kv(const std::string & key, const std::string & value)
+            : key(key), is_array(false), type(GGUF_TYPE_STRING) {
+        GGML_ASSERT(!key.empty());
+        data_string.push_back(value);
+    }
+
+    gguf_kv(const std::string & key, const std::vector<std::string> & value)
+            : key(key), is_array(true), type(GGUF_TYPE_STRING) {
+        GGML_ASSERT(!key.empty());
+        data_string = value;
+    }
+
+    const std::string & get_key() const {
+        return key;
+    }
+
+    const enum gguf_type & get_type() const {
+        return type;
+    }
+
+    size_t get_ne() const {
+        if (type == GGUF_TYPE_STRING) {
+            const size_t ne = data_string.size();
+            GGML_ASSERT(is_array || ne == 1);
+            return ne;
+        }
+        const size_t type_size = gguf_type_size(type);
+        GGML_ASSERT(data.size() % type_size == 0);
+        const size_t ne = data.size() / type_size;
+        GGML_ASSERT(is_array || ne == 1);
+        return ne;
+    }
+
+    template <typename T>
+    const T & get_val(const size_t i = 0) const {
+        GGML_ASSERT(type_to_gguf_type<T>::value == type);
+        if constexpr (std::is_same<T, std::string>::value) {
+            GGML_ASSERT(data_string.size() >= i+1);
+            return data_string[i];
+        }
+        const size_t type_size = gguf_type_size(type);
+        GGML_ASSERT(data.size() % type_size == 0);
+        GGML_ASSERT(data.size() >= (i+1)*type_size);
+        return reinterpret_cast<const T *>(data.data())[i];
+    }
+
+    void cast(const enum gguf_type new_type) {
+        const size_t new_type_size = gguf_type_size(new_type);
+        GGML_ASSERT(data.size() % new_type_size == 0);
+        type = new_type;
+    }
+};
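+
+// Illustrative sketch of how a scalar KV pair round-trips through gguf_kv (not part of the
+// upstream file; the key name is made up):
+//
+//   gguf_kv kv("example.scale", 1.25f);     // type == GGUF_TYPE_FLOAT32, is_array == false
+//   const float x = kv.get_val<float>();    // reinterprets the 4 stored bytes back as float
+//   const size_t n = kv.get_ne();           // == 1 for a scalar value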
+
+struct gguf_tensor_info {
+    struct ggml_tensor t; // for holding the equivalent info
+    uint64_t offset;      // offset from start of `data`, must be a multiple of `ALIGNMENT`
+};
+
+struct gguf_context {
+    uint32_t version = GGUF_VERSION;
+
+    std::vector<struct gguf_kv> kv;
+    std::vector<struct gguf_tensor_info> info;
+
+    size_t alignment = GGUF_DEFAULT_ALIGNMENT;
+    size_t offset    = 0; // offset of `data` from beginning of file
+    size_t size      = 0; // size of `data` in bytes
+
+    void * data = nullptr;
+};
+
+struct gguf_reader {
+    FILE * file;
+
+    gguf_reader(FILE * file) : file(file) {}
+
+    template <typename T>
+    bool read(T & dst) const {
+        return fread(&dst, 1, sizeof(dst), file) == sizeof(dst);
+    }
+
+    template <typename T>
+    bool read(std::vector<T> & dst, const size_t n) const {
+        dst.resize(n);
+        for (size_t i = 0; i < dst.size(); ++i) {
+            if constexpr (std::is_same<T, bool>::value) {
+                bool tmp;
+                if (!read(tmp)) {
+                    return false;
+                }
+                dst[i] = tmp;
+            } else {
+                if (!read(dst[i])) {
+                    return false;
+                }
+            }
+        }
+        return true;
+    }
+
+    bool read(bool & dst) const {
+        int8_t tmp = -1;
+        if (!read(tmp)) {
+            return false;
+        }
+        dst = tmp != 0;
+        return true;
+    }
+
+    bool read(enum ggml_type & dst) const {
+        int32_t tmp = -1;
+        if (!read(tmp)) {
+            return false;
+        }
+        dst = ggml_type(tmp);
+        return true;
+    }
+
+    bool read(enum gguf_type & dst) const {
+        int32_t tmp = -1;
+        if (!read(tmp)) {
+            return false;
+        }
+        dst = gguf_type(tmp);
+        return true;
+    }
+
+    bool read(std::string & dst) const {
+        uint64_t size = -1;
+        if (!read(size)) {
+            return false;
+        }
+        dst.resize(size);
+        return fread(dst.data(), 1, dst.length(), file) == dst.length();
+    }
+
+    bool read(void * dst, const size_t size) const {
+        return fread(dst, 1, size, file) == size;
+    }
+};
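+
+// Illustrative note (not part of the upstream file): a GGUF string is serialized as a 64-bit
+// length followed by that many raw bytes (little-endian in the default encoding), which is
+// exactly what read(std::string &) above consumes. For example, the key "general.architecture"
+// (20 bytes) is laid out as:
+//
+//   14 00 00 00 00 00 00 00  g e n e r a l . a r c h i t e c t u r e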
+
+struct gguf_context * gguf_init_empty(void) {
+    return new gguf_context;
+}
+
+template<typename T>
+bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector<struct gguf_kv> & kv, const std::string & key, const bool is_array, const size_t n) {
+    if (is_array) {
+        std::vector<T> value;
+        try {
+            if (!gr.read(value, n)) {
+                return false;
+            }
+        } catch (std::length_error &) {
+            fprintf(stderr, "%s: encountered length_error while reading value for key '%s'\n", __func__, key.c_str());
+            return false;
+        } catch (std::bad_alloc &) {
+            fprintf(stderr, "%s: encountered bad_alloc error while reading value for key '%s'\n", __func__, key.c_str());
+            return false;
+        }
+        kv.emplace_back(key, value);
+    } else {
+        T value;
+        if (!gr.read(value)) {
+            return false;
+        }
+        kv.emplace_back(key, value);
+    }
+    return true;
+}
+
+struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) {
+    const struct gguf_reader gr(file);
+    struct gguf_context * ctx = new gguf_context;
+
+    bool ok = true;
+
+    // file magic
+    {
+        std::vector<char> magic;
+        ok = ok && gr.read(magic, 4);
+
+        if (!ok) {
+            fprintf(stderr, "%s: failed to read magic\n", __func__);
+            gguf_free(ctx);
+            return nullptr;
+        }
+
+        for (uint32_t i = 0; i < magic.size(); i++) {
+            if (magic[i] != GGUF_MAGIC[i]) {
+                fprintf(stderr, "%s: invalid magic characters: '%c%c%c%c', expected 'GGUF'\n", __func__, magic[0], magic[1], magic[2], magic[3]);
+                gguf_free(ctx);
+                return nullptr;
+            }
+        }
+    }
+
+    // header
+    int64_t n_kv      = 0;
+    int64_t n_tensors = 0;
+
+    if (ok && gr.read(ctx->version)) {
+        if (ctx->version == 1) {
+            fprintf(stderr, "%s: GGUFv1 is no longer supported, please use a more up-to-date version\n", __func__);
+            ok = false;
+        }
+        if (ctx->version > GGUF_VERSION) {
+            fprintf(stderr, "%s: this GGUF file is version %" PRIu32 " but this software only supports up to version %d\n",
+                __func__, ctx->version, GGUF_VERSION);
+            ok = false;
+        }
+    } else {
+        ok = false;
+    }
+
+    if (ok && gr.read(n_tensors)) {
+        static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing");
+        if (n_tensors < 0 || n_tensors > int64_t(SIZE_MAX/sizeof(gguf_tensor_info))) {
+            fprintf(stderr, "%s: number of tensors is %" PRIi64 " but must be in [0, %zu]\n",
+                __func__, n_tensors, SIZE_MAX/sizeof(gguf_tensor_info));
+            ok = false;
+        }
+    } else {
+        ok = false;
+    }
+
+    if (ok && gr.read(n_kv)) {
+        static_assert(sizeof(size_t) <= 8 && sizeof(gguf_tensor_info) >= 2, "int64_t insufficient for indexing");
+        if (n_kv < 0 || n_kv > int64_t(SIZE_MAX/sizeof(gguf_kv))) {
+            fprintf(stderr, "%s: number of key value pairs is %" PRIi64 " but must be in [0, %zu]\n",
+                    __func__, n_kv, SIZE_MAX/sizeof(gguf_kv));
+            ok = false;
+        }
+    } else {
+        ok = false;
+    }
+
+    if (!ok) {
+        fprintf(stderr, "%s: failed to read header\n", __func__);
+        gguf_free(ctx);
+        return nullptr;
+    }
+
+    // KV pairs
+    {
+        for (int64_t i = 0; ok && i < n_kv; ++i) {
+            std::string key;
+            gguf_type   type     = gguf_type(-1);
+            bool        is_array = false;
+            uint64_t    n        = 1;
+
+            try {
+                ok = ok && gr.read(key);
+            } catch (std::length_error &) {
+                fprintf(stderr, "%s: encountered length_error while reading key %" PRIi64 "\n", __func__, i);
+                ok = false;
+            } catch (std::bad_alloc &) {
+                fprintf(stderr, "%s: encountered bad_alloc error while reading key %" PRIi64 "\n", __func__, i);
+                ok = false;
+            }
+            for (size_t j = 0; ok && j < ctx->kv.size(); ++j) {
+                if (key == ctx->kv[j].key) {
+                    fprintf(stderr, "%s: duplicate key '%s' for tensors %zu and %" PRIi64 " \n", __func__, key.c_str(), j, i);
+                    ok = false;
+                }
+            }
+            if (!ok) {
+                break;
+            }
+
+            ok = ok && gr.read(type);
+            if (type == GGUF_TYPE_ARRAY) {
+                is_array = true;
+                ok = ok && gr.read(type);
+                ok = ok && gr.read(n);
+            }
+            if (!ok) {
+                break;
+            }
+
+            switch (type) {
+                case GGUF_TYPE_UINT8:   ok = ok && gguf_read_emplace_helper<uint8_t    >(gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_INT8:    ok = ok && gguf_read_emplace_helper<int8_t     >(gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_UINT16:  ok = ok && gguf_read_emplace_helper<uint16_t   >(gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_INT16:   ok = ok && gguf_read_emplace_helper<int16_t    >(gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_UINT32:  ok = ok && gguf_read_emplace_helper<uint32_t   >(gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_INT32:   ok = ok && gguf_read_emplace_helper<int32_t    >(gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_FLOAT32: ok = ok && gguf_read_emplace_helper<float      >(gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_BOOL:    ok = ok && gguf_read_emplace_helper<bool       >(gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_STRING:  ok = ok && gguf_read_emplace_helper<std::string>(gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_UINT64:  ok = ok && gguf_read_emplace_helper<uint64_t   >(gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_INT64:   ok = ok && gguf_read_emplace_helper<int64_t    >(gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_FLOAT64: ok = ok && gguf_read_emplace_helper<double     >(gr, ctx->kv, key, is_array, n); break;
+                case GGUF_TYPE_ARRAY:
+                default:
+                    {
+                        fprintf(stderr, "%s: key '%s' has invalid GGUF type %d\n", __func__, key.c_str(), type);
+                        ok = false;
+                    } break;
+            }
+        }
+
+        if (!ok) {
+            fprintf(stderr, "%s: failed to read key-value pairs\n", __func__);
+            gguf_free(ctx);
+            return nullptr;
+        }
+        GGML_ASSERT(int64_t(ctx->kv.size()) == n_kv);
+
+        const int alignment_idx = gguf_find_key(ctx, GGUF_KEY_GENERAL_ALIGNMENT);
+        ctx->alignment = alignment_idx == -1 ? GGUF_DEFAULT_ALIGNMENT : gguf_get_val_u32(ctx, alignment_idx);
+
+        if (ctx->alignment == 0 || (ctx->alignment & (ctx->alignment - 1)) != 0) {
+            fprintf(stderr, "%s: alignment %zu is not a power of 2\n", __func__, ctx->alignment);
+            gguf_free(ctx);
+            return nullptr;
+        }
+    }
+
+    // read the tensor info
+    for (int64_t i = 0; ok && i < n_tensors; ++i) {
+        struct gguf_tensor_info info;
+
+        // tensor name
+        {
+            std::string name;
+            try {
+                ok = ok && gr.read(name);
+            } catch (std::length_error &) {
+                fprintf(stderr, "%s: encountered length_error while reading tensor name %" PRIi64 "\n", __func__, i);
+                ok = false;
+            } catch (std::bad_alloc &) {
+                fprintf(stderr, "%s: encountered bad_alloc error while reading tensor name %" PRIi64 "\n", __func__, i);
+                ok = false;
+            }
+            if (name.length() >= GGML_MAX_NAME) {
+                fprintf(stderr, "%s: tensor name %" PRIi64 " is too long: %zu >= %d\n", __func__, i, name.length(), GGML_MAX_NAME);
+                ok = false;
+                break;
+            }
+            ggml_set_name(&info.t, name.c_str());
+
+            // make sure there are no duplicate tensor names
+            for (int64_t j = 0; ok && j < i; ++j) {
+                if (strcmp(info.t.name, ctx->info[j].t.name) == 0) {
+                    fprintf(stderr, "%s: duplicate tensor name '%s' for tensors %" PRIi64 " and %" PRIi64 "\n", __func__, info.t.name, j, i);
+                    ok = false;
+                    break;
+                }
+            }
+        }
+        if (!ok) {
+            break;
+        }
+
+        // tensor shape
+        {
+            uint32_t n_dims = -1;
+            ok = ok && gr.read(n_dims);
+            if (n_dims > GGML_MAX_DIMS) {
+                fprintf(stderr, "%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n",
+                    __func__, info.t.name, n_dims, GGML_MAX_DIMS);
+                ok = false;
+                break;
+            }
+            for (uint32_t j = 0; ok && j < GGML_MAX_DIMS; ++j) {
+                info.t.ne[j] = 1;
+                if (j < n_dims) {
+                    ok = ok && gr.read(info.t.ne[j]);
+                }
+
+                // check that all ne are non-negative
+                if (info.t.ne[j] < 0) {
+                    fprintf(stderr, "%s: tensor '%s' dimension %" PRIu32 " has invalid number of elements: %" PRIi64 " < 0\n",
+                        __func__, info.t.name, j, info.t.ne[j]);
+                    ok = false;
+                    break;
+                }
+            }
+
+            // check that the total number of elements is representable
+            if (ok && ((INT64_MAX/info.t.ne[1] <= info.t.ne[0]) ||
+                       (INT64_MAX/info.t.ne[2] <= info.t.ne[0]*info.t.ne[1]) ||
+                       (INT64_MAX/info.t.ne[3] <= info.t.ne[0]*info.t.ne[1]*info.t.ne[2]))) {
+
+                fprintf(stderr, "%s: total number of elements in tensor '%s' with shape "
+                    "(%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") is >= %" PRIi64 "\n",
+                    __func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], INT64_MAX);
+                ok = false;
+                break;
+            }
+        }
+        if (!ok) {
+            break;
+        }
+
+        // tensor type
+        {
+            ok = ok && gr.read(info.t.type);
+
+            // check that tensor type is within defined range
+            if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) {
+                fprintf(stderr, "%s: tensor '%s' has invalid ggml type %d (%s)\n",
+                    __func__, info.t.name, info.t.type, ggml_type_name(info.t.type));
+                ok = false;
+                break;
+            }
+            const size_t  type_size = ggml_type_size(info.t.type);
+            const int64_t blck_size = ggml_blck_size(info.t.type);
+
+            // check that row size is divisible by block size
+            if (blck_size == 0 || info.t.ne[0] % blck_size != 0) {
+                fprintf(stderr, "%s: tensor '%s' of type %d (%s) has %" PRId64 " elements per row, "
+                    "not a multiple of block size (%" PRId64 ")\n",
+                    __func__, info.t.name, (int) info.t.type, ggml_type_name(info.t.type), info.t.ne[0], blck_size);
+                ok = false;
+                break;
+            }
+
+            // calculate byte offsets given the tensor shape and type
+            info.t.nb[0] = type_size;
+            info.t.nb[1] = info.t.nb[0]*(info.t.ne[0]/blck_size);
+            for (int j = 2; j < GGML_MAX_DIMS; ++j) {
+                info.t.nb[j] = info.t.nb[j - 1]*info.t.ne[j - 1];
+            }
+        }
+        if (!ok) {
+            break;
+        }
+
+        // tensor data offset within buffer
+        ok = ok && gr.read(info.offset);
+
+        ctx->info.push_back(info);
+    }
+
+    if (!ok) {
+        fprintf(stderr, "%s: failed to read tensor info\n", __func__);
+        gguf_free(ctx);
+        return nullptr;
+    }
+    GGML_ASSERT(int64_t(ctx->info.size()) == n_tensors);
+
+    // we require the data section to be aligned, so take into account any padding
+    if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) {
+        fprintf(stderr, "%s: failed to seek to beginning of data section\n", __func__);
+        gguf_free(ctx);
+        return nullptr;
+    }
+
+    // store the current file offset - this is where the data section starts
+    ctx->offset = ftell(file);
+
+    // compute the total size of the data section, taking into account the alignment
+    {
+        ctx->size = 0;
+        for (size_t i = 0; i < ctx->info.size(); ++i) {
+            const gguf_tensor_info & ti = ctx->info[i];
+            if (ti.offset != ctx->size) {
+                fprintf(stderr, "%s: tensor '%s' has offset %" PRIu64 ", expected %zu\n",
+                    __func__, ti.t.name, ti.offset, ctx->size);
+                fprintf(stderr, "%s: failed to read tensor data\n", __func__);
+                gguf_free(ctx);
+                return nullptr;
+            }
+            ctx->size += GGML_PAD(ggml_nbytes(&ti.t), ctx->alignment);
+        }
+    }
+
+    // load the tensor data only if requested
+    if (params.ctx != nullptr) {
+        // if the provided gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob
+        // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of
+        //   the ggml_tensor structs to the appropriate locations in the binary blob
+
+        // compute the exact size needed for the new ggml_context
+        const size_t mem_size =
+            params.no_alloc ?
+            (n_tensors    )*ggml_tensor_overhead() :
+            (n_tensors + 1)*ggml_tensor_overhead() + ctx->size;
+
+        struct ggml_init_params pdata = {
+            /*mem_size   =*/ mem_size,
+            /*mem_buffer =*/ nullptr,
+            /*no_alloc   =*/ params.no_alloc,
+        };
+
+        *params.ctx = ggml_init(pdata);
+        if (*params.ctx == nullptr) {
+            fprintf(stderr, "%s: failed to initialize ggml context for storing tensors\n", __func__);
+            gguf_free(ctx);
+            return nullptr;
+        }
+
+        struct ggml_context * ctx_data = *params.ctx;
+
+        struct ggml_tensor * data = nullptr;
+
+        if (!params.no_alloc) {
+            data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size);
+
+            ok = ok && data != nullptr;
+
+            // read the binary blob with the tensor data
+            ok = ok && gr.read(data->data, ctx->size);
+
+            if (!ok) {
+                fprintf(stderr, "%s: failed to read tensor data binary blob\n", __func__);
+                ggml_free(ctx_data);
+                *params.ctx = nullptr;
+                gguf_free(ctx);
+                return nullptr;
+            }
+
+            ctx->data = data->data;
+        }
+
+        ggml_set_no_alloc(ctx_data, true);
+
+        // create the tensors
+        for (size_t i = 0; i < ctx->info.size(); ++i) {
+            const struct gguf_tensor_info & info = ctx->info[i];
+
+            struct ggml_tensor * cur = ggml_new_tensor(ctx_data, info.t.type, GGML_MAX_DIMS, info.t.ne);
+
+            ok = ok && cur != nullptr;
+
+            if (!ok) {
+                break;
+            }
+
+            ggml_set_name(cur, info.t.name);
+
+            // point the data member to the appropriate location in the binary blob using the tensor info
+            if (!params.no_alloc) {
+                cur->data = (char *) data->data + info.offset;
+            }
+        }
+
+        if (!ok) {
+            fprintf(stderr, "%s: failed to create tensors\n", __func__);
+            ggml_free(ctx_data);
+            *params.ctx = nullptr;
+            gguf_free(ctx);
+            return nullptr;
+        }
+
+        ggml_set_no_alloc(ctx_data, params.no_alloc);
+    }
+
+    return ctx;
+}
+
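+// minimal usage sketch (the file path is a placeholder):
+//     ggml_context * ctx_data = nullptr;
+//     gguf_init_params params = { /*no_alloc =*/ false, /*ctx =*/ &ctx_data };
+//     gguf_context * gctx = gguf_init_from_file("model.gguf", params);
+//     if (gctx) { /* ... query KV pairs / tensors ... */ gguf_free(gctx); ggml_free(ctx_data); }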
+struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) {
+    FILE * file = ggml_fopen(fname, "rb");
+
+    if (!file) {
+        fprintf(stderr, "%s: failed to open GGUF file '%s'\n", __func__, fname);
+        return nullptr;
+    }
+
+    struct gguf_context * result = gguf_init_from_file_impl(file, params);
+    fclose(file);
+    return result;
+}
+
+void gguf_free(struct gguf_context * ctx) {
+    if (ctx == nullptr) {
+        return;
+    }
+    delete ctx;
+}
+
+const char * gguf_type_name(enum gguf_type type) {
+    auto it = GGUF_TYPE_NAME.find(type);
+    return it == GGUF_TYPE_NAME.end() ? nullptr : it->second;
+}
+
+uint32_t gguf_get_version(const struct gguf_context * ctx) {
+    return ctx->version;
+}
+
+size_t gguf_get_alignment(const struct gguf_context * ctx) {
+    return ctx->alignment;
+}
+
+size_t gguf_get_data_offset(const struct gguf_context * ctx) {
+    return ctx->offset;
+}
+
+int64_t gguf_get_n_kv(const struct gguf_context * ctx) {
+    return ctx->kv.size();
+}
+
+int64_t gguf_find_key(const struct gguf_context * ctx, const char * key) {
+    // return -1 if key not found
+    int64_t keyfound = -1;
+
+    const int64_t n_kv = gguf_get_n_kv(ctx);
+
+    for (int64_t i = 0; i < n_kv; ++i) {
+        if (strcmp(key, gguf_get_key(ctx, i)) == 0) {
+            keyfound = i;
+            break;
+        }
+    }
+
+    return keyfound;
+}
+
+const char * gguf_get_key(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    return ctx->kv[key_id].get_key().c_str();
+}
+
+enum gguf_type gguf_get_kv_type(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    return ctx->kv[key_id].is_array ? GGUF_TYPE_ARRAY : ctx->kv[key_id].get_type();
+}
+
+enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].is_array);
+    return ctx->kv[key_id].get_type();
+}
+
+const void * gguf_get_arr_data(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
+    return ctx->kv[key_id].data.data();
+}
+
+const char * gguf_get_arr_str(const struct gguf_context * ctx, int64_t key_id, size_t i) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_type() == GGUF_TYPE_STRING);
+    return ctx->kv[key_id].data_string[i].c_str();
+}
+
+size_t gguf_get_arr_n(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+
+    if (ctx->kv[key_id].type == GGUF_TYPE_STRING) {
+        return ctx->kv[key_id].data_string.size();
+    }
+
+    const size_t type_size = gguf_type_size(ctx->kv[key_id].type);
+    GGML_ASSERT(ctx->kv[key_id].data.size() % type_size == 0);
+    return ctx->kv[key_id].data.size() / type_size;
+}
+
+uint8_t gguf_get_val_u8(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<uint8_t>();
+}
+
+int8_t gguf_get_val_i8(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<int8_t>();
+}
+
+uint16_t gguf_get_val_u16(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<uint16_t>();
+}
+
+int16_t gguf_get_val_i16(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<int16_t>();
+}
+
+uint32_t gguf_get_val_u32(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<uint32_t>();
+}
+
+int32_t gguf_get_val_i32(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<int32_t>();
+}
+
+float gguf_get_val_f32(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<float>();
+}
+
+uint64_t gguf_get_val_u64(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<uint64_t>();
+}
+
+int64_t gguf_get_val_i64(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<int64_t>();
+}
+
+double gguf_get_val_f64(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<double>();
+}
+
+bool gguf_get_val_bool(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<bool>();
+}
+
+const char * gguf_get_val_str(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    return ctx->kv[key_id].get_val<std::string>().c_str();
+}
+
+const void * gguf_get_val_data(const struct gguf_context * ctx, int64_t key_id) {
+    GGML_ASSERT(key_id >= 0 && key_id < gguf_get_n_kv(ctx));
+    GGML_ASSERT(ctx->kv[key_id].get_ne() == 1);
+    GGML_ASSERT(ctx->kv[key_id].get_type() != GGUF_TYPE_STRING);
+    return ctx->kv[key_id].data.data();
+}
+
+int64_t gguf_get_n_tensors(const struct gguf_context * ctx) {
+    return ctx->info.size();
+}
+
+int64_t gguf_find_tensor(const struct gguf_context * ctx, const char * name) {
+    // return -1 if tensor not found
+    int64_t tensor_id = -1;
+
+    const int64_t n_tensors = gguf_get_n_tensors(ctx);
+
+    for (int64_t i = 0; i < n_tensors; ++i) {
+        if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) {
+            tensor_id = i;
+            break;
+        }
+    }
+
+    return tensor_id;
+}
+
+size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int64_t tensor_id) {
+    GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx));
+    return ctx->info[tensor_id].offset;
+}
+
+const char * gguf_get_tensor_name(const struct gguf_context * ctx, int64_t tensor_id) {
+    GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx));
+    return ctx->info[tensor_id].t.name;
+}
+
+enum ggml_type gguf_get_tensor_type(const struct gguf_context * ctx, int64_t tensor_id) {
+    GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx));
+    return ctx->info[tensor_id].t.type;
+}
+
+size_t gguf_get_tensor_size(const struct gguf_context * ctx, int64_t tensor_id) {
+    GGML_ASSERT(tensor_id >= 0 && tensor_id < gguf_get_n_tensors(ctx));
+    return ggml_nbytes(&ctx->info[tensor_id].t);
+}
+
+int64_t gguf_remove_key(struct gguf_context * ctx, const char * key) {
+    const int64_t key_id = gguf_find_key(ctx, key);
+    if (key_id >= 0) {
+        ctx->kv.erase(ctx->kv.begin() + key_id);
+    }
+    return key_id;
+}
+
+template <typename T>
+static void gguf_check_reserved_keys(const std::string & key, const T val) {
+    if (key == GGUF_KEY_GENERAL_ALIGNMENT) {
+        if constexpr (std::is_same<T, uint32_t>::value) {
+            GGML_ASSERT(val > 0 && (val & (val - 1)) == 0 && GGUF_KEY_GENERAL_ALIGNMENT " must be power of 2");
+        } else {
+            GGML_ABORT(GGUF_KEY_GENERAL_ALIGNMENT " must be type u32");
+        }
+    }
+}
+
+void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_u64(struct gguf_context * ctx, const char * key, uint64_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_i64(struct gguf_context * ctx, const char * key, int64_t val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_f64(struct gguf_context * ctx, const char * key, double val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, val);
+}
+
+void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) {
+    gguf_check_reserved_keys(key, val);
+    gguf_remove_key(ctx, key);
+    ctx->kv.emplace_back(key, std::string(val));
+}
+
+void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, size_t n) {
+    gguf_check_reserved_keys(key, data);
+    gguf_remove_key(ctx, key);
+
+    const size_t nbytes = n*gguf_type_size(type);
+    std::vector<int8_t> tmp(nbytes);
+    if (!tmp.empty()) {
+        memcpy(tmp.data(), data, nbytes);
+    }
+    ctx->kv.emplace_back(key, tmp);
+    ctx->kv.back().cast(type);
+}
+
+void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, size_t n) {
+    gguf_check_reserved_keys(key, data);
+    gguf_remove_key(ctx, key);
+
+    std::vector<std::string> tmp(n);
+    for (size_t i = 0; i < n; ++i) {
+        tmp[i] = data[i];
+    }
+    ctx->kv.emplace_back(key, tmp);
+}
+
+// set or add KV pairs from another context
+void gguf_set_kv(struct gguf_context * ctx, const struct gguf_context * src) {
+    const int64_t n_kv = gguf_get_n_kv(src);
+    for (int64_t i = 0; i < n_kv; ++i) {
+        const struct gguf_kv & kv = src->kv[i];
+
+        if (!kv.is_array) {
+            switch (kv.get_type()) {
+                case GGUF_TYPE_UINT8:   gguf_set_val_u8  (ctx, kv.get_key().c_str(), kv.get_val<uint8_t    >());     break;
+                case GGUF_TYPE_INT8:    gguf_set_val_i8  (ctx, kv.get_key().c_str(), kv.get_val<int8_t     >());     break;
+                case GGUF_TYPE_UINT16:  gguf_set_val_u16 (ctx, kv.get_key().c_str(), kv.get_val<uint16_t   >());     break;
+                case GGUF_TYPE_INT16:   gguf_set_val_i16 (ctx, kv.get_key().c_str(), kv.get_val<int16_t    >());     break;
+                case GGUF_TYPE_UINT32:  gguf_set_val_u32 (ctx, kv.get_key().c_str(), kv.get_val<uint32_t   >());     break;
+                case GGUF_TYPE_INT32:   gguf_set_val_i32 (ctx, kv.get_key().c_str(), kv.get_val<int32_t    >());     break;
+                case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, kv.get_key().c_str(), kv.get_val<float      >());     break;
+                case GGUF_TYPE_UINT64:  gguf_set_val_u64 (ctx, kv.get_key().c_str(), kv.get_val<uint64_t   >());     break;
+                case GGUF_TYPE_INT64:   gguf_set_val_i64 (ctx, kv.get_key().c_str(), kv.get_val<int64_t    >());     break;
+                case GGUF_TYPE_FLOAT64: gguf_set_val_f64 (ctx, kv.get_key().c_str(), kv.get_val<double     >());     break;
+                case GGUF_TYPE_BOOL:    gguf_set_val_bool(ctx, kv.get_key().c_str(), kv.get_val<bool       >());     break;
+                case GGUF_TYPE_STRING:  gguf_set_val_str (ctx, kv.get_key().c_str(), kv.get_val<std::string>().c_str()); break;
+                case GGUF_TYPE_ARRAY:
+                default: GGML_ABORT("invalid type");
+            }
+            continue;
+        }
+
+        const size_t ne = kv.get_ne();
+
+        switch (kv.get_type()) {
+            case GGUF_TYPE_UINT8:
+            case GGUF_TYPE_INT8:
+            case GGUF_TYPE_UINT16:
+            case GGUF_TYPE_INT16:
+            case GGUF_TYPE_UINT32:
+            case GGUF_TYPE_INT32:
+            case GGUF_TYPE_FLOAT32:
+            case GGUF_TYPE_UINT64:
+            case GGUF_TYPE_INT64:
+            case GGUF_TYPE_FLOAT64:
+            case GGUF_TYPE_BOOL: {
+                gguf_set_arr_data(ctx, kv.get_key().c_str(), kv.get_type(), kv.data.data(), ne);
+            } break;
+            case GGUF_TYPE_STRING: {
+                std::vector<const char *> tmp(ne);
+                for (size_t j = 0; j < ne; ++j) {
+                    tmp[j] = kv.data_string[j].c_str();
+                }
+                gguf_set_arr_str(ctx, kv.get_key().c_str(), tmp.data(), ne);
+            } break;
+            case GGUF_TYPE_ARRAY:
+            default: GGML_ABORT("invalid type");
+        }
+    }
+}
+
+void gguf_add_tensor(
+             struct gguf_context * ctx,
+        const struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor);
+    if (gguf_find_tensor(ctx, tensor->name) != -1) {
+        GGML_ABORT("duplicate tensor name: %s", tensor->name);
+    }
+
+    struct gguf_tensor_info ti;
+    ti.t = *tensor;
+    ti.offset = ctx->info.empty() ? 0 :
+        ctx->info.back().offset + GGML_PAD(ggml_nbytes(&ctx->info.back().t), ctx->alignment);
+    ctx->info.push_back(ti);
+}
+
+void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) {
+    const int64_t tensor_id = gguf_find_tensor(ctx, name);
+    if (tensor_id < 0) {
+        GGML_ABORT("tensor not found: %s", name);
+    }
+    struct ggml_tensor * tensor = &ctx->info[tensor_id].t;
+    const size_t  type_size = ggml_type_size(type);
+    const int64_t blck_size = ggml_blck_size(type);
+
+    tensor->type = type;
+    GGML_ASSERT(tensor->ne[0] % blck_size == 0 && "tensor row size not divisible by block size of new type");
+
+    tensor->nb[0] = type_size;
+    tensor->nb[1] = tensor->nb[0]*(tensor->ne[0]/blck_size);
+    for (int i = 2; i < GGML_MAX_DIMS; i++) {
+        tensor->nb[i] = tensor->nb[i - 1]*tensor->ne[i - 1];
+    }
+
+    // update offsets
+    const int64_t n_tensors = gguf_get_n_tensors(ctx);
+    for (int64_t i = tensor_id + 1; i < n_tensors; ++i) {
+        ctx->info[i].offset = ctx->info[i - 1].offset + GGML_PAD(ggml_nbytes(&ctx->info[i - 1].t), ctx->alignment);
+    }
+}
+
+void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data) {
+    const int64_t tensor_id = gguf_find_tensor(ctx, name);
+    if (tensor_id < 0) {
+        GGML_ABORT("tensor not found: %s", name);
+    }
+
+    ctx->info[tensor_id].t.data = (void *)(uintptr_t)data; // double cast suppresses warning about casting away const
+}
+
+struct gguf_writer {
+    std::vector<int8_t> & buf;
+
+    gguf_writer(std::vector<int8_t> & buf) : buf(buf) {}
+
+    template <typename T>
+    void write(const T & val) const {
+        for (size_t i = 0; i < sizeof(val); ++i) {
+            buf.push_back(reinterpret_cast<const int8_t *>(&val)[i]);
+        }
+    }
+
+    void write(const std::vector<int8_t> & val) const {
+        buf.insert(buf.end(), val.begin(), val.end());
+    }
+
+    void write(const bool & val) const {
+        const int8_t val8 = val ? 1 : 0;
+        write(val8);
+    }
+
+    void write(const std::string & val) const {
+        {
+            const uint64_t n = val.length();
+            write(n);
+        }
+        for (size_t i = 0; i < val.length(); ++i) {
+            buf.push_back(reinterpret_cast<const int8_t *>(val.data())[i]);
+        }
+    }
+
+    void write(const char * val) const {
+        write(std::string(val));
+    }
+
+    void write(const enum ggml_type & val) const {
+        write(int32_t(val));
+    }
+
+    void write(const enum gguf_type & val) const {
+        write(int32_t(val));
+    }
+
+    void write(const struct gguf_kv & kv) const {
+        const uint64_t ne = kv.get_ne();
+
+        write(kv.get_key());
+
+        if (kv.is_array) {
+            write(GGUF_TYPE_ARRAY);
+            write(kv.get_type());
+            write(ne);
+        } else {
+            write(kv.get_type());
+        }
+
+        switch (kv.get_type()) {
+            case GGUF_TYPE_UINT8:
+            case GGUF_TYPE_INT8:
+            case GGUF_TYPE_UINT16:
+            case GGUF_TYPE_INT16:
+            case GGUF_TYPE_UINT32:
+            case GGUF_TYPE_INT32:
+            case GGUF_TYPE_FLOAT32:
+            case GGUF_TYPE_UINT64:
+            case GGUF_TYPE_INT64:
+            case GGUF_TYPE_FLOAT64: {
+                write(kv.data);
+            } break;
+            case GGUF_TYPE_BOOL: {
+                for (size_t i = 0; i < ne; ++i) {
+                    write(kv.get_val<bool>(i));
+                }
+            } break;
+            case GGUF_TYPE_STRING: {
+                for (size_t i = 0; i < ne; ++i) {
+                    write(kv.get_val<std::string>(i));
+                }
+            } break;
+            case GGUF_TYPE_ARRAY:
+            default: GGML_ABORT("invalid type");
+        }
+    }
+
+    void write_tensor_meta(const struct gguf_tensor_info & info) const {
+        write(info.t.name);
+
+        const uint32_t n_dims = ggml_n_dims(&info.t);
+        write(n_dims);
+
+        for (uint32_t j = 0; j < n_dims; ++j) {
+            write(info.t.ne[j]);
+        }
+        write(info.t.type);
+        write(info.offset);
+    }
+
+    void pad(const size_t alignment) const {
+        while (buf.size() % alignment != 0) {
+            const int8_t zero = 0;
+            write(zero);
+        }
+    }
+
+    void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) const {
+        GGML_ASSERT(buf.size() - offset_data == info.offset);
+
+        GGML_ASSERT(ggml_is_contiguous(&info.t));
+        const size_t offset = buf.size();
+        const size_t nbytes = ggml_nbytes(&info.t);
+
+        buf.resize(offset + nbytes);
+        if (info.t.buffer) {
+            ggml_backend_tensor_get(&info.t, buf.data() + offset, 0, nbytes);
+        } else {
+            GGML_ASSERT(info.t.data);
+            memcpy(buf.data() + offset, info.t.data, nbytes);
+        }
+
+        pad(alignment);
+    }
+};
+
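+// serialize the GGUF image into buf: header, KV pairs, tensor infos, alignment padding and
+// (unless only_meta is set) the tensor data, mirroring the layout parsed by gguf_init_from_file_impl above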
+void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & buf, bool only_meta) {
+    const struct gguf_writer gw(buf);
+
+    const int64_t n_kv      = gguf_get_n_kv(ctx);
+    const int64_t n_tensors = gguf_get_n_tensors(ctx);
+
+    // write header
+    gw.write(GGUF_MAGIC[0]);
+    gw.write(GGUF_MAGIC[1]);
+    gw.write(GGUF_MAGIC[2]);
+    gw.write(GGUF_MAGIC[3]);
+    gw.write(ctx->version);
+    gw.write(n_tensors);
+    gw.write(n_kv);
+
+    // write key-value pairs
+    for (int64_t i = 0; i < n_kv; ++i) {
+        gw.write(ctx->kv[i]);
+    }
+
+    // write tensor info
+    for (int64_t i = 0; i < n_tensors; ++i) {
+        gw.write_tensor_meta(ctx->info[i]);
+    }
+
+    // we require the data section to be aligned
+    gw.pad(ctx->alignment);
+
+    if (only_meta) {
+        return;
+    }
+
+    const size_t offset_data = gw.buf.size();
+
+    // write tensor data
+    for (int64_t i = 0; i < n_tensors; ++i) {
+        gw.write_tensor_data(ctx->info[i], offset_data, ctx->alignment);
+    }
+}
+
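+// minimal writer sketch (key, value, tensor and path are placeholders):
+//     gguf_context * gctx = gguf_init_empty();
+//     gguf_set_val_str(gctx, "general.name", "example");
+//     gguf_add_tensor(gctx, tensor);                      // ggml_tensor with its data already set
+//     gguf_write_to_file(gctx, "out.gguf", /*only_meta =*/ false);
+//     gguf_free(gctx);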
+bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
+    FILE * file = ggml_fopen(fname, "wb");
+
+    if (!file) {
+        fprintf(stderr, "%s: failed to open file '%s' for writing GGUF data\n", __func__, fname);
+        return false;
+    }
+
+    std::vector<int8_t> buf;
+    gguf_write_to_buf(ctx, buf, only_meta);
+    const bool ok = fwrite(buf.data(), 1, buf.size(), file) == buf.size();
+    fclose(file);
+    return ok;
+}
+
+size_t gguf_get_meta_size(const struct gguf_context * ctx) {
+    // only return size
+    std::vector<int8_t> buf;
+    gguf_write_to_buf(ctx, buf, /*only_meta =*/ true);
+    return buf.size();
+}
+
+void gguf_get_meta_data(const struct gguf_context * ctx, void * data) {
+    std::vector<int8_t> buf;
+    gguf_write_to_buf(ctx, buf, /*only_meta =*/ true);
+    memcpy(data, buf.data(), buf.size());
+}
diff --git a/scripts/sync-ggml-am.sh b/scripts/sync-ggml-am.sh
index 29038c8683a..f5defa75920 100755
--- a/scripts/sync-ggml-am.sh
+++ b/scripts/sync-ggml-am.sh
@@ -157,8 +157,8 @@ if [ -f $SRC_WHISPER/ggml-src.patch ]; then
         -e 's/([[:space:]]|[ab]\/)src\/ggml-rpc\//\1ggml\/src\/ggml-rpc\//g' \
         -e 's/([[:space:]]|[ab]\/)src\/ggml-sycl\//\1ggml\/src\/ggml-sycl\//g' \
         -e 's/([[:space:]]|[ab]\/)src\/ggml-vulkan\//\1ggml\/src\/ggml-vulkan\//g' \
-        -e 's/([[:space:]]|[ab]\/)include\/ggml(.*)\.h/\1ggml\/include\/ggml\2.h/g' \
-        -e 's/([[:space:]]|[ab]\/)include\/gguf(.*)\.h/\1ggml\/include\/gguf\2.h/g' \
+        -e 's/(^[[:space:]]|[ab]\/)include\/ggml(.*)\.h/\1ggml\/include\/ggml\2.h/g' \
+        -e 's/(^[[:space:]]|[ab]\/)include\/gguf(.*)\.h/\1ggml\/include\/gguf\2.h/g' \
         -e 's/(^[[:space:]]|[ab]\/)examples\/common\.h/\1examples\/common.h/g' \
         -e 's/(^[[:space:]]|[ab]\/)examples\/common\.cpp/\1examples\/common.cpp/g' \
         -e 's/(^[[:space:]]|[ab]\/)examples\/common-ggml\.h/\1examples\/common-ggml.h/g' \
diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last
index 1a052dd9927..1a990ac68c3 100644
--- a/scripts/sync-ggml.last
+++ b/scripts/sync-ggml.last
@@ -1 +1 @@
-e61b9f5df05f44714128507f31128b89c7fb3134
+f33c42adaf5d7fc24093f3975c162d905d8e111a