Skip to content

Commit

Permalink
server : add lora hotswap endpoint (WIP) (ggerganov#8857)
Browse files Browse the repository at this point in the history
* server : add lora hotswap endpoint

* handle lora_no_apply

* fix build

* updae docs

* clean up struct def

* fix build

* add LoRA test

* fix style
  • Loading branch information
ngxson authored Aug 6, 2024
1 parent 641f5dd commit 1e6f655
Show file tree
Hide file tree
Showing 9 changed files with 251 additions and 92 deletions.
60 changes: 42 additions & 18 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,14 +684,24 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
if (arg == "--lora") {
CHECK_ARG
params.lora_adapter.emplace_back(argv[i], 1.0f);
params.lora_adapters.push_back({
std::string(argv[i]),
1.0,
});
return true;
}
if (arg == "--lora-scaled") {
CHECK_ARG
const char* lora_adapter = argv[i];
std::string lora_adapter = argv[i];
CHECK_ARG
params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i]));
params.lora_adapters.push_back({
lora_adapter,
std::stof(argv[i]),
});
return true;
}
if (arg == "--lora-init-without-apply") {
params.lora_init_without_apply = true;
return true;
}
if (arg == "--control-vector") {
Expand Down Expand Up @@ -1654,6 +1664,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
"https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
options.push_back({ "server", "-sps, --slot-prompt-similarity SIMILARITY",
"how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity });
options.push_back({ "server", " --lora-init-without-apply", "load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"});

#ifndef LOG_DISABLE_LOGS
options.push_back({ "logging" });
Expand Down Expand Up @@ -2091,17 +2102,22 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
}
}

for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
float lora_scale = std::get<1>(params.lora_adapter[i]);
auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
if (adapter == nullptr) {
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
// load and optionally apply lora adapters
for (auto & la : params.lora_adapters) {
llama_lora_adapter_container loaded_la;
loaded_la.path = la.path;
loaded_la.scale = la.scale;
loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
if (loaded_la.adapter == nullptr) {
fprintf(stderr, "%s: error: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
llama_free(lctx);
llama_free_model(model);
return iparams;
}
llama_lora_adapter_set(lctx, adapter, lora_scale);
iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters
}
if (!params.lora_init_without_apply) {
llama_lora_adapters_apply(lctx, iparams.lora_adapters);
}

if (params.ignore_eos) {
Expand Down Expand Up @@ -2140,6 +2156,15 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
return iparams;
}

void llama_lora_adapters_apply(struct llama_context * ctx, std::vector<llama_lora_adapter_container> & lora_adapters) {
llama_lora_adapter_clear(ctx);
for (auto & la : lora_adapters) {
if (la.scale != 0.0f) {
llama_lora_adapter_set(ctx, la.adapter, la.scale);
}
}
}

struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
auto mparams = llama_model_default_params();

Expand Down Expand Up @@ -3162,19 +3187,18 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
}

fprintf(stream, "lora:\n");
for (std::tuple<std::string, float> la : params.lora_adapter) {
if (std::get<1>(la) != 1.0f) {
continue;
for (auto & la : params.lora_adapters) {
if (la.scale == 1.0f) {
fprintf(stream, " - %s\n", la.path.c_str());
}
fprintf(stream, " - %s\n", std::get<0>(la).c_str());
}
fprintf(stream, "lora_scaled:\n");
for (std::tuple<std::string, float> la : params.lora_adapter) {
if (std::get<1>(la) == 1.0f) {
continue;
for (auto & la : params.lora_adapters) {
if (la.scale != 1.0f) {
fprintf(stream, " - %s: %f\n", la.path.c_str(), la.scale);
}
fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la));
}
fprintf(stream, "lora_init_without_apply: %s # default: false\n", params.lora_init_without_apply ? "true" : "false");
fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep);
fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
Expand Down
19 changes: 16 additions & 3 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@

#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"

struct llama_lora_adapter_info {
std::string path;
float scale;
};

struct llama_lora_adapter_container : llama_lora_adapter_info {
struct llama_lora_adapter * adapter;
};

// build info
extern int LLAMA_BUILD_NUMBER;
extern char const * LLAMA_COMMIT;
Expand Down Expand Up @@ -126,8 +135,8 @@ struct gpt_params {
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
std::vector<llama_model_kv_override> kv_overrides;

// TODO: avoid tuple, use struct
std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply)
std::vector<llama_lora_adapter_info> lora_adapters; // lora adapter path with user defined scale

std::vector<llama_control_vector_load_info> control_vectors; // control vector with user defined scale

Expand Down Expand Up @@ -309,8 +318,9 @@ std::string fs_get_cache_file(const std::string & filename);
//

struct llama_init_result {
struct llama_model * model = nullptr;
struct llama_model * model = nullptr;
struct llama_context * context = nullptr;
std::vector<llama_lora_adapter_container> lora_adapters;
};

struct llama_init_result llama_init_from_gpt_params(gpt_params & params);
Expand All @@ -321,6 +331,9 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params);
struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params);

// clear LoRA adapters from context, then apply new list of adapters
void llama_lora_adapters_apply(struct llama_context * ctx, std::vector<llama_lora_adapter_container> & lora_adapters);

// Batch utils

void llama_batch_clear(struct llama_batch & batch);
Expand Down
10 changes: 5 additions & 5 deletions examples/export-lora/export-lora.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ struct lora_merge_ctx {

lora_merge_ctx(
std::string & base_fname,
std::vector<std::tuple<std::string, float>> & lora_files,
std::vector<llama_lora_adapter_info> & lora_files,
std::string & outfile,
int n_threads) : base_model(base_fname, 0), n_threads(n_threads), fout(outfile, std::ios::binary) {
fout.exceptions(std::ofstream::failbit); // fail fast on write errors
Expand All @@ -144,9 +144,9 @@ struct lora_merge_ctx {
throw std::runtime_error("split model is not yet supported");
}

for (auto lora_inp : lora_files) {
auto fname = std::get<0>(lora_inp);
auto scale = std::get<1>(lora_inp);
for (auto & lora_inp : lora_files) {
auto fname = lora_inp.path;
auto scale = lora_inp.scale;
std::unique_ptr<file_input> adapter(new file_input(fname, scale));
check_metadata_lora(adapter.get());
adapters.push_back(std::move(adapter));
Expand Down Expand Up @@ -407,7 +407,7 @@ int main(int argc, char ** argv) {

g_verbose = (params.verbosity == 1);
try {
lora_merge_ctx ctx(params.model, params.lora_adapter, params.lora_outfile, params.n_threads);
lora_merge_ctx ctx(params.model, params.lora_adapters, params.lora_outfile, params.n_threads);
ctx.run_merge();
} catch (const std::exception & err) {
fprintf(stderr, "%s\n", err.what());
Expand Down
Loading

0 comments on commit 1e6f655

Please sign in to comment.