Skip to content

Commit

Permalink
common : add -hfd option for the draft model (ggerganov#11318)
Browse files (browse the repository at this point in the history)
* common : add -hfd option for the draft model

* cont : fix env var

* cont : more fixes
  • Loading branch information
ggerganov authored Jan 20, 2025
1 parent aea8ddd commit 80d0d6b
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 6 deletions.
17 changes: 13 additions & 4 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ static void common_params_handle_model_default(
const std::string & model_url,
std::string & hf_repo,
std::string & hf_file,
const std::string & hf_token) {
const std::string & hf_token,
const std::string & model_default) {
if (!hf_repo.empty()) {
// short-hand to avoid specifying --hf-file -> default it to --model
if (hf_file.empty()) {
Expand Down Expand Up @@ -163,7 +164,7 @@ static void common_params_handle_model_default(
model = fs_get_cache_file(string_split<std::string>(f, '/').back());
}
} else if (model.empty()) {
model = DEFAULT_MODEL_PATH;
model = model_default;
}
}

Expand Down Expand Up @@ -299,8 +300,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
}

// TODO: refactor model params in a common struct
common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token);
common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token);
common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file, params.hf_token, DEFAULT_MODEL_PATH);
common_params_handle_model_default(params.speculative.model, params.speculative.model_url, params.speculative.hf_repo, params.speculative.hf_file, params.hf_token, "");
common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file, params.hf_token, "");

if (params.escape) {
string_process_escapes(params.prompt);
Expand Down Expand Up @@ -1629,6 +1631,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.hf_repo = value;
}
).set_env("LLAMA_ARG_HF_REPO"));
// Draft-model counterpart of --hf-repo (aliases: -hfd, -hfrd, --hf-repo-draft).
// The handler stores the given "<user>/<model>[:quant]" value in
// params.speculative.hf_repo; the option is tagged with the
// LLAMA_ARG_HFD_REPO name through set_env().
add_opt(common_arg(
{"-hfd", "-hfrd", "--hf-repo-draft"}, "<user>/<model>[:quant]",
"Same as --hf-repo, but for the draft model (default: unused)",
[](common_params & params, const std::string & value) {
params.speculative.hf_repo = value;
}
).set_env("LLAMA_ARG_HFD_REPO"));
add_opt(common_arg(
{"-hff", "--hf-file"}, "FILE",
"Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)",
Expand Down
8 changes: 7 additions & 1 deletion common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,11 @@ struct common_params_speculative {
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;

std::string model = ""; // draft model for speculative decoding // NOLINT
std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT

std::string model = ""; // draft model for speculative decoding // NOLINT
std::string model_url = ""; // model url to download // NOLINT
};

struct common_params_vocoder {
Expand Down Expand Up @@ -508,12 +512,14 @@ struct llama_model * common_load_model_from_url(
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params);

struct llama_model * common_load_model_from_hf(
const std::string & repo,
const std::string & remote_path,
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params);

std::pair<std::string, std::string> common_get_hf_file(
const std::string & hf_repo_with_tag,
const std::string & hf_token);
Expand Down
5 changes: 4 additions & 1 deletion examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1728,13 +1728,16 @@ struct server_context {
add_bos_token = llama_vocab_get_add_bos(vocab);
has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;

if (!params_base.speculative.model.empty()) {
if (!params_base.speculative.model.empty() || !params_base.speculative.hf_repo.empty()) {
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());

auto params_dft = params_base;

params_dft.devices = params_base.speculative.devices;
params_dft.hf_file = params_base.speculative.hf_file;
params_dft.hf_repo = params_base.speculative.hf_repo;
params_dft.model = params_base.speculative.model;
params_dft.model_url = params_base.speculative.model_url;
params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
params_dft.n_parallel = 1;
Expand Down

0 comments on commit 80d0d6b

Please sign in to comment.