From 1ecb6a6999cc47afea6d8ca63c25a657685f68be Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Wed, 29 May 2024 06:37:54 +0000 Subject: [PATCH 1/5] server : Smart selection of available slot using Longest Common Substring --- examples/server/server.cpp | 141 +++++++++++++++++++++++++++++++++---- examples/server/utils.hpp | 49 +++++++++++++ 2 files changed, 175 insertions(+), 15 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index fc6d90848f099..209df41c8e47c 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -144,6 +144,7 @@ struct server_params { bool slots_endpoint = true; bool metrics_endpoint = false; std::string slot_save_path; + float lcs_similarity = 0.0f; }; struct server_slot { @@ -670,6 +671,9 @@ struct server_context { server_metrics metrics; + // Longest Common Substring similarity for slot selection + float lcs_similarity = 0.0f; + ~server_context() { if (ctx) { llama_free(ctx); @@ -818,24 +822,88 @@ struct server_context { return prompt_tokens; } - server_slot * get_slot(int id) { - int64_t t_last = ggml_time_us(); - - server_slot * last_used = nullptr; - + server_slot * get_slot_by_id(int id) { for (server_slot & slot : slots) { - if (slot.id == id && slot.available()) { + if (slot.id == id) { return &slot; } + } + + return nullptr; + } + + server_slot * get_available_slot(const std::string & prompt) { + server_slot * ret = nullptr; + + // find the slot that has at least n% prompt similarity + if (ret == nullptr && lcs_similarity != 0.0f && !prompt.empty()) { + int max_lcs_len = 0; + float similarity = 0; + + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (!slot.available()) { + continue; + } + + // skip the slot if it does not contains prompt + if (!slot.prompt.is_string()) { + continue; + } + + // current slot's prompt + std::string slot_prompt = slot.prompt.get(); + + // length of the current slot's prompt + int slot_prompt_len = slot_prompt.size(); + + // length of the longest common substring between the current slot's prompt and the input prompt + int lcs_len = lcs_length(slot_prompt, prompt); + + // fraction of the common substring length compared to the current slot's prompt length + similarity = static_cast(lcs_len) / slot_prompt_len; + + // select the current slot if the criteria match + if (lcs_len > max_lcs_len && similarity > lcs_similarity) { + max_lcs_len = lcs_len; + ret = &slot; + } + } + + if (ret != nullptr) { + LOG_VERBOSE("selected slot by lcs similarity", { + {"id_slot", ret->id}, + {"max_lcs_len", max_lcs_len}, + {"similarity", similarity}, + }); + } + } + + // find the slot that has been least recently used + if (ret == nullptr) { + int64_t t_last = ggml_time_us(); + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (!slot.available()) { + continue; + } + + // select the current slot if the criteria match + if (slot.t_last_used < t_last) { + t_last = slot.t_last_used; + ret = &slot; + } + } - // among all available slots, find the one that has been least recently used - if (slot.available() && slot.t_last_used < t_last) { - last_used = &slot; - t_last = slot.t_last_used; + if (ret != nullptr) { + LOG_VERBOSE("selected slot by lru", { + {"id_slot", ret->id}, + {"t_last", t_last}, + }); } } - return last_used; + return ret; } bool launch_slot_with_task(server_slot & slot, const server_task & task) { @@ -1538,13 +1606,29 @@ struct server_context { switch (task.type) { case SERVER_TASK_TYPE_COMPLETION: { - server_slot * slot = get_slot(json_value(task.data, "id_slot", -1)); + int id_slot = json_value(task.data, "id_slot", -1); + std::string prompt = json_value(task.data, "prompt", std::string()); + + server_slot * slot; + + if (id_slot != -1) { + slot = get_slot_by_id(id_slot); + } else { + slot = get_available_slot(prompt); + } + if (slot == nullptr) { // if no slot is available, we defer this task for processing later LOG_VERBOSE("no slot is available", {{"id_task", task.id}}); queue_tasks.defer(task); break; } + if (!slot->available()) { + // if requested slot is unavailable, we defer this task for processing later + LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); + queue_tasks.defer(task); + break; + } if (task.data.contains("system_prompt")) { std::string sys_prompt = json_value(task.data, "system_prompt", std::string()); @@ -1661,11 +1745,17 @@ struct server_context { case SERVER_TASK_TYPE_SLOT_SAVE: { int id_slot = task.data.at("id_slot"); - server_slot * slot = get_slot(id_slot); + server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); break; } + if (!slot->available()) { + // if requested slot is unavailable, we defer this task for processing later + LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); + queue_tasks.defer(task); + break; + } const size_t token_count = slot->cache_tokens.size(); const int64_t t_start = ggml_time_us(); @@ -1696,11 +1786,17 @@ struct server_context { case SERVER_TASK_TYPE_SLOT_RESTORE: { int id_slot = task.data.at("id_slot"); - server_slot * slot = get_slot(id_slot); + server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); break; } + if (!slot->available()) { + // if requested slot is unavailable, we defer this task for processing later + LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); + queue_tasks.defer(task); + break; + } const int64_t t_start = ggml_time_us(); @@ -1738,11 +1834,17 @@ struct server_context { case SERVER_TASK_TYPE_SLOT_ERASE: { int id_slot = task.data.at("id_slot"); - server_slot * slot = get_slot(id_slot); + server_slot * slot = get_slot_by_id(id_slot); if (slot == nullptr) { send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); break; } + if (!slot->available()) { + // if requested slot is unavailable, we defer this task for processing later + LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}}); + queue_tasks.defer(task); + break; + } // Erase token cache const size_t n_erased = slot->cache_tokens.size(); @@ -2868,6 +2970,12 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, invalid_param = true; break; } + } else if (arg == "--lcs-similarity") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.lcs_similarity = std::stof(argv[i]); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); server_print_usage(argv[0], default_params, default_sparams); @@ -3039,6 +3147,9 @@ int main(int argc, char ** argv) { log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded"; } + // Longest Common Substring similarity for slot selection + ctx_server.lcs_similarity = sparams.lcs_similarity; + // load the model if (!ctx_server.load_model(params)) { state.store(SERVER_STATE_ERROR); diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index d8a2286e4b1df..68e6953dc1da9 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -653,3 +653,52 @@ static json format_error_response(const std::string & message, const enum error_ {"type", type_str}, }; } + +static int lcs_length(const std::string & str1, const std::string & str2) { + // check for empty strings + if (str1.empty() || str2.empty()) { + return 0; + } + + // get the lengths of the input strings + int str1_len = str1.size(); + int str2_len = str2.size(); + + // initialize the maximum length of the longest common subsequence (LCS) + int max_length = 0; + + // use two rows instead of a 2D matrix to optimize space + std::vector prev_row(str2_len + 1, 0); + std::vector curr_row(str2_len + 1, 0); + + // iterate through the characters of str1 + for (int i = 1; i <= str1_len; i++) { + // iterate through the characters of str2 + for (int j = 1; j <= str2_len; j++) { + // if characters at the current positions match + if (str1[i - 1] == str2[j - 1]) { + // if it's the first character of either string, set LCS length to 1 + if (i == 1 || j == 1) { + curr_row[j] = 1; + } else { + // increment LCS length by 1 compared to the previous character + curr_row[j] = prev_row[j - 1] + 1; + } + + // update max_length if necessary + if (curr_row[j] > max_length) { + max_length = curr_row[j]; + } + } else { + // reset LCS length if characters don't match + curr_row[j] = 0; + } + } + + // update the previous row for the next iteration + prev_row = curr_row; + } + + // return the maximum length of the LCS + return max_length; +} From 2df61bf59e746e8b0ed51373784bbef75284567e Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Wed, 5 Jun 2024 08:41:11 +0000 Subject: [PATCH 2/5] add usage --- common/common.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/common/common.cpp b/common/common.cpp index 0cbf8263c5806..42c594bbd29aa 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1839,6 +1839,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param "set custom jinja chat template (default: template taken from model's metadata)\n" "only commonly used templates are accepted:\n" "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" }); + options.push_back({ "server", " --lcs-similarity SIMILARITY", + "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f)\n", params.lcs_similarity }); #ifndef LOG_DISABLE_LOGS options.push_back({ "logging" }); From f1164112de761b133dee9674ddfd87e6b31832c5 Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Wed, 5 Jun 2024 09:02:01 +0000 Subject: [PATCH 3/5] remove trailing whitespaces --- examples/server/server.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index b4e1e58ab3765..b90a0b8f39b2d 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -821,12 +821,12 @@ struct server_context { if (!slot.available()) { continue; } - + // skip the slot if it does not contains prompt if (!slot.prompt.is_string()) { continue; } - + // current slot's prompt std::string slot_prompt = slot.prompt.get(); @@ -1586,7 +1586,7 @@ struct server_context { std::string prompt = json_value(task.data, "prompt", std::string()); server_slot * slot; - + if (id_slot != -1) { slot = get_slot_by_id(id_slot); } else { From 36083dca2c92c92c2d36acaaf766872d0e37730c Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Fri, 7 Jun 2024 14:09:11 +0000 Subject: [PATCH 4/5] Use Longest Common Prefix (LCP) instead of LCS --- common/common.cpp | 8 +++--- common/common.h | 2 +- examples/server/server.cpp | 26 +++++++++--------- examples/server/utils.hpp | 56 +++++--------------------------------- 4 files changed, 25 insertions(+), 67 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 42c594bbd29aa..65448c918d98e 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1460,12 +1460,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.chat_template = argv[i]; return true; } - if (arg == "--lcs-similarity") { + if (arg == "--lcp-similarity") { if (++i >= argc) { invalid_param = true; return true; } - params.lcs_similarity = std::stof(argv[i]); + params.lcp_similarity = std::stof(argv[i]); return true; } if (arg == "-pps") { @@ -1839,8 +1839,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param "set custom jinja chat template (default: template taken from model's metadata)\n" "only commonly used templates are accepted:\n" "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" }); - options.push_back({ "server", " --lcs-similarity SIMILARITY", - "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f)\n", params.lcs_similarity }); + options.push_back({ "server", " --lcp-similarity SIMILARITY", + "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f)\n", params.lcp_similarity }); #ifndef LOG_DISABLE_LOGS options.push_back({ "logging" }); diff --git a/common/common.h b/common/common.h index 0c9c592a4d439..0a8a9c0739a00 100644 --- a/common/common.h +++ b/common/common.h @@ -202,7 +202,7 @@ struct gpt_params { std::string slot_save_path; - float lcs_similarity = 0.0f; + float lcp_similarity = 0.0f; // batched-bench params bool is_pp_shared = false; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index b90a0b8f39b2d..802c660c7a005 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -647,8 +647,8 @@ struct server_context { server_metrics metrics; - // Longest Common Substring similarity for slot selection - float lcs_similarity = 0.0f; + // Longest Common Prefix similarity for slot selection + float lcp_similarity = 0.0f; ~server_context() { if (ctx) { @@ -812,8 +812,8 @@ struct server_context { server_slot * ret = nullptr; // find the slot that has at least n% prompt similarity - if (ret == nullptr && lcs_similarity != 0.0f && !prompt.empty()) { - int max_lcs_len = 0; + if (ret == nullptr && lcp_similarity != 0.0f && !prompt.empty()) { + int max_lcp_len = 0; float similarity = 0; for (server_slot & slot : slots) { @@ -833,23 +833,23 @@ struct server_context { // length of the current slot's prompt int slot_prompt_len = slot_prompt.size(); - // length of the longest common substring between the current slot's prompt and the input prompt - int lcs_len = lcs_length(slot_prompt, prompt); + // length of the Longest Common Prefix between the current slot's prompt and the input prompt + int lcp_len = common_part(slot_prompt, prompt); // fraction of the common substring length compared to the current slot's prompt length - similarity = static_cast(lcs_len) / slot_prompt_len; + similarity = static_cast(lcp_len) / slot_prompt_len; // select the current slot if the criteria match - if (lcs_len > max_lcs_len && similarity > lcs_similarity) { - max_lcs_len = lcs_len; + if (lcp_len > max_lcp_len && similarity > lcp_similarity) { + max_lcp_len = lcp_len; ret = &slot; } } if (ret != nullptr) { - LOG_VERBOSE("selected slot by lcs similarity", { + LOG_VERBOSE("selected slot by lcp similarity", { {"id_slot", ret->id}, - {"max_lcs_len", max_lcs_len}, + {"max_lcp_len", max_lcp_len}, {"similarity", similarity}, }); } @@ -2568,8 +2568,8 @@ int main(int argc, char ** argv) { log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded"; } - // Longest Common Substring similarity for slot selection - ctx_server.lcs_similarity = params.lcs_similarity; + // Longest Common Prefix similarity for slot selection + ctx_server.lcp_similarity = params.lcp_similarity; // load the model if (!ctx_server.load_model(params)) { diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 904f5e3c08f6e..63fde9c9faabe 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -253,6 +253,13 @@ static size_t common_part(const std::vector & a, const std::vector< return i; } +static size_t common_part(const std::string & a, const std::string & b) { + size_t i; + for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} + + return i; +} + static bool ends_with(const std::string & str, const std::string & suffix) { return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); } @@ -646,52 +653,3 @@ static json format_error_response(const std::string & message, const enum error_ {"type", type_str}, }; } - -static int lcs_length(const std::string & str1, const std::string & str2) { - // check for empty strings - if (str1.empty() || str2.empty()) { - return 0; - } - - // get the lengths of the input strings - int str1_len = str1.size(); - int str2_len = str2.size(); - - // initialize the maximum length of the longest common subsequence (LCS) - int max_length = 0; - - // use two rows instead of a 2D matrix to optimize space - std::vector prev_row(str2_len + 1, 0); - std::vector curr_row(str2_len + 1, 0); - - // iterate through the characters of str1 - for (int i = 1; i <= str1_len; i++) { - // iterate through the characters of str2 - for (int j = 1; j <= str2_len; j++) { - // if characters at the current positions match - if (str1[i - 1] == str2[j - 1]) { - // if it's the first character of either string, set LCS length to 1 - if (i == 1 || j == 1) { - curr_row[j] = 1; - } else { - // increment LCS length by 1 compared to the previous character - curr_row[j] = prev_row[j - 1] + 1; - } - - // update max_length if necessary - if (curr_row[j] > max_length) { - max_length = curr_row[j]; - } - } else { - // reset LCS length if characters don't match - curr_row[j] = 0; - } - } - - // update the previous row for the next iteration - prev_row = curr_row; - } - - // return the maximum length of the LCS - return max_length; -} From a8842fdf56dc725b69c19332d46dc8bbf612069e Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Fri, 7 Jun 2024 14:27:29 +0000 Subject: [PATCH 5/5] Rename argument --- common/common.cpp | 8 ++++---- common/common.h | 2 +- examples/server/server.cpp | 12 ++++++------ 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 65448c918d98e..c829fc7921512 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1460,12 +1460,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.chat_template = argv[i]; return true; } - if (arg == "--lcp-similarity") { + if (arg == "--slot-prompt-similarity" || arg == "-sps") { if (++i >= argc) { invalid_param = true; return true; } - params.lcp_similarity = std::stof(argv[i]); + params.slot_prompt_similarity = std::stof(argv[i]); return true; } if (arg == "-pps") { @@ -1839,8 +1839,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param "set custom jinja chat template (default: template taken from model's metadata)\n" "only commonly used templates are accepted:\n" "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" }); - options.push_back({ "server", " --lcp-similarity SIMILARITY", - "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f)\n", params.lcp_similarity }); + options.push_back({ "server", "-sps, --slot-prompt-similarity SIMILARITY", + "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity }); #ifndef LOG_DISABLE_LOGS options.push_back({ "logging" }); diff --git a/common/common.h b/common/common.h index 0a8a9c0739a00..a093b05c4e1e2 100644 --- a/common/common.h +++ b/common/common.h @@ -202,7 +202,7 @@ struct gpt_params { std::string slot_save_path; - float lcp_similarity = 0.0f; + float slot_prompt_similarity = 0.5f; // batched-bench params bool is_pp_shared = false; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 802c660c7a005..0f3d03bc63401 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -647,8 +647,8 @@ struct server_context { server_metrics metrics; - // Longest Common Prefix similarity for slot selection - float lcp_similarity = 0.0f; + // Necessary similarity of prompt for slot selection + float slot_prompt_similarity = 0.0f; ~server_context() { if (ctx) { @@ -812,7 +812,7 @@ struct server_context { server_slot * ret = nullptr; // find the slot that has at least n% prompt similarity - if (ret == nullptr && lcp_similarity != 0.0f && !prompt.empty()) { + if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) { int max_lcp_len = 0; float similarity = 0; @@ -840,7 +840,7 @@ struct server_context { similarity = static_cast(lcp_len) / slot_prompt_len; // select the current slot if the criteria match - if (lcp_len > max_lcp_len && similarity > lcp_similarity) { + if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) { max_lcp_len = lcp_len; ret = &slot; } @@ -2568,8 +2568,8 @@ int main(int argc, char ** argv) { log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded"; } - // Longest Common Prefix similarity for slot selection - ctx_server.lcp_similarity = params.lcp_similarity; + // Necessary similarity of prompt for slot selection + ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; // load the model if (!ctx_server.load_model(params)) {