llama : expose llama_model_n_head_kv in the API (ggml-org#11997)
It's useful to be able to have this from the library layer as it's a key
parameter of the model (e.g. to figure out how much KV cache memory is
needed).
vlovich authored and orca-zhang committed Feb 26, 2025
1 parent 3ba18ef commit 0714a04
Showing 2 changed files with 5 additions and 0 deletions.
1 change: 1 addition & 0 deletions include/llama.h
@@ -477,6 +477,7 @@ extern "C" {
     LLAMA_API int32_t llama_model_n_embd     (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_layer    (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head     (const struct llama_model * model);
+    LLAMA_API int32_t llama_model_n_head_kv  (const struct llama_model * model);

     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
4 changes: 4 additions & 0 deletions src/llama-model.cpp
@@ -3838,6 +3838,10 @@ int32_t llama_model_n_head(const struct llama_model * model) {
     return model->hparams.n_head();
 }

+int32_t llama_model_n_head_kv(const struct llama_model * model) {
+    return model->hparams.n_head_kv();
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const struct llama_model * model) {
     return llama_model_n_ctx_train(model);
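The commit message points at the main use case: sizing the KV cache from the library layer. Below is a minimal sketch of that calculation. It assumes a uniform head size of n_embd / n_head (some models override the per-head K/V dimensions), an f16 cache (2 bytes per element), and the current model-loading entry points (llama_backend_init, llama_model_load_from_file, llama_model_free); it is an illustration of the new getter, not the exact formula llama.cpp uses internally.

// kv_estimate.c - sketch of estimating KV cache size via llama_model_n_head_kv.
// Assumptions (not from the commit): head_dim = n_embd / n_head, f16 K/V cache.
#include <stdio.h>
#include "llama.h"

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s model.gguf\n", argv[0]);
        return 1;
    }

    llama_backend_init();

    struct llama_model_params mparams = llama_model_default_params();
    struct llama_model * model = llama_model_load_from_file(argv[1], mparams);
    if (!model) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    const int32_t n_layer   = llama_model_n_layer(model);
    const int32_t n_embd    = llama_model_n_embd(model);
    const int32_t n_head    = llama_model_n_head(model);
    const int32_t n_head_kv = llama_model_n_head_kv(model); // new in this commit

    const int64_t n_ctx    = 8192;             // desired context length
    const int64_t head_dim = n_embd / n_head;  // assumed: uniform head size
    const int64_t bytes_el = 2;                // assumed: f16 cache elements

    // K and V each store n_ctx * n_head_kv * head_dim elements per layer,
    // hence the leading factor of 2.
    const int64_t kv_bytes = 2 * n_layer * n_ctx * n_head_kv * head_dim * bytes_el;

    printf("estimated KV cache: %.1f MiB for n_ctx=%lld\n",
           kv_bytes / (1024.0 * 1024.0), (long long) n_ctx);

    llama_model_free(model);
    llama_backend_free();
    return 0;
}

For grouped-query-attention models, n_head_kv is smaller than n_head, which is why the estimate must use the KV head count exposed here rather than the attention head count.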
