Allow for user specified embedding pooling type #5849

Merged: 2 commits, Mar 3, 2024
13 changes: 13 additions & 0 deletions common/common.cpp
@@ -335,6 +335,16 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.yarn_beta_slow = std::stof(argv[i]);
} else if (arg == "--pooling") {
if (++i >= argc) {
invalid_param = true;
break;
}
std::string value(argv[i]);
/**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
else { invalid_param = true; break; }
} else if (arg == "--defrag-thold" || arg == "-dt") {
if (++i >= argc) {
invalid_param = true;
@@ -1014,6 +1024,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
printf(" --pooling {none,mean,cls}\n");
printf(" pooling type for embeddings, use model default if unspecified\n");
printf(" -dt N, --defrag-thold N\n");
printf(" KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
@@ -1296,6 +1308,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.yarn_beta_fast = params.yarn_beta_fast;
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
cparams.pooling_type = params.pooling_type;
cparams.defrag_thold = params.defrag_thold;
cparams.offload_kqv = !params.no_kv_offload;

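To show where the new option ends up, here is a rough usage sketch (not part of this PR; it assumes the usual common.h entry points such as gpt_params_parse()):

#include "common.h"
#include "llama.h"

int main(int argc, char ** argv) {
    gpt_params params;

    // The parser now accepts e.g. "--pooling mean" and records the choice in
    // params.pooling_type; without the flag it stays LLAMA_POOLING_TYPE_UNSPECIFIED,
    // so the model's own default is used.
    if (!gpt_params_parse(argc, argv, params)) {
        return 1;
    }

    // llama_context_params_from_gpt_params() carries the choice into cparams.pooling_type.
    struct llama_context_params cparams = llama_context_params_from_gpt_params(params);
    // ... load the model and create the context with cparams as usual ...
    return 0;
}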
1 change: 1 addition & 0 deletions common/common.h
@@ -78,6 +78,7 @@ struct gpt_params {
float defrag_thold = -1.0f; // KV cache defragmentation threshold
int32_t rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings

// // sampling parameters
struct llama_sampling_params sparams;
18 changes: 9 additions & 9 deletions convert-hf-to-gguf.py
@@ -1644,16 +1644,17 @@ def set_gguf_parameters(self):
self.gguf_writer.add_causal_attention(False)

# get pooling path
with open(self.dir_model / "modules.json", encoding="utf-8") as f:
modules = json.load(f)
pooling_path = None
for mod in modules:
if mod["type"] == "sentence_transformers.models.Pooling":
pooling_path = mod["path"]
break
module_path = self.dir_model / "modules.json"
if module_path.is_file():
with open(module_path, encoding="utf-8") as f:
modules = json.load(f)
for mod in modules:
if mod["type"] == "sentence_transformers.models.Pooling":
pooling_path = mod["path"]
break

# get pooling type
pooling_type = gguf.PoolingType.NONE
if pooling_path is not None:
with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f:
pooling = json.load(f)
@@ -1663,8 +1664,7 @@ def set_gguf_parameters(self):
pooling_type = gguf.PoolingType.CLS
else:
raise NotImplementedError("Only MEAN and CLS pooling types supported")

self.gguf_writer.add_pooling_type(pooling_type)
self.gguf_writer.add_pooling_type(pooling_type)

def set_vocab(self):
path = self.dir_model
28 changes: 20 additions & 8 deletions llama.cpp
@@ -1683,7 +1683,7 @@ struct llama_cparams {
float defrag_thold;

bool offload_kqv;
bool do_pooling;
enum llama_pooling_type pooling_type;

ggml_backend_sched_eval_callback cb_eval;
void * cb_eval_user_data;
@@ -2931,7 +2931,11 @@ template<>
bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
uint32_t tmp;
const bool found = get_key(kid, tmp, required);
result = (enum llama_pooling_type) tmp;
if (found) {
result = (enum llama_pooling_type) tmp;
} else {
result = LLAMA_POOLING_TYPE_UNSPECIFIED;
}
return found;
}

@@ -3208,7 +3212,7 @@ static void llm_load_hparams(
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);

switch (hparams.n_layer) {
case 3:
@@ -5173,7 +5177,7 @@ struct llm_build_context {
n_kv (worst_case ? n_ctx : kv_self.n),
kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
n_orig_ctx (cparams.n_yarn_orig_ctx),
pooling_type (cparams.do_pooling ? hparams.pooling_type : LLAMA_POOLING_TYPE_NONE),
pooling_type (cparams.pooling_type),
rope_type (hparams.rope_type),
cb (cb),
buf_compute_meta (lctx.buf_compute_meta) {
@@ -8013,7 +8017,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}

if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
const int64_t n_tokens = batch.n_tokens;

GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
@@ -8041,7 +8045,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}

if (cparams.do_pooling && hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
const int64_t n_tokens = batch.n_tokens;

GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
@@ -11859,7 +11863,7 @@ struct llama_context_params llama_context_default_params() {
/*.logits_all =*/ false,
/*.embedding =*/ false,
/*.offload_kqv =*/ true,
/*.do_pooling =*/ true,
/*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
};
@@ -12010,7 +12014,7 @@ struct llama_context * llama_new_context_with_model(
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.defrag_thold = params.defrag_thold;
cparams.offload_kqv = params.offload_kqv;
cparams.do_pooling = params.do_pooling;
cparams.pooling_type = params.pooling_type;

cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@@ -12036,6 +12040,14 @@
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
}

if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
} else {
cparams.pooling_type = hparams.pooling_type;
}
}

if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
}
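From the caller's side, the new fallback means an application can either pin a pooling type or defer to the model's metadata. A minimal sketch against the public API (not from this PR; model loading elided):

// Assumes `model` was loaded earlier, e.g. via llama_load_model_from_file().
struct llama_context_params cparams = llama_context_default_params();
cparams.embedding    = true;                             // embedding mode
cparams.pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED;   // inherit the model's pooling metadata
                                                         // (which itself falls back to NONE);
                                                         // set _NONE / _MEAN / _CLS to override
struct llama_context * ctx = llama_new_context_with_model(model, cparams);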
3 changes: 2 additions & 1 deletion llama.h
@@ -129,6 +129,7 @@ extern "C" {
};

enum llama_pooling_type {
LLAMA_POOLING_TYPE_UNSPECIFIED = -1,
LLAMA_POOLING_TYPE_NONE = 0,
LLAMA_POOLING_TYPE_MEAN = 1,
LLAMA_POOLING_TYPE_CLS = 2,
@@ -258,7 +259,7 @@ extern "C" {
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embedding; // embedding mode only
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool do_pooling; // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
Owner


Move this next to rope_scaling_type above and pass as int32_t

Collaborator Author


Should I have it be an int32_t in llama_cparams and llm_build_context as well? Otherwise it will need a static_cast<llama_pooling_type> in llama_new_context_with_model.
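For reference, the cast mentioned above would look roughly like this if the public field were changed to int32_t (hypothetical, not part of this change):

// Inside llama_new_context_with_model(), assuming `int32_t pooling_type;` in llama_context_params:
cparams.pooling_type = static_cast<enum llama_pooling_type>(params.pooling_type);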


// Abort callback
// if it returns true, execution of llama_decode() will be aborted