Server: Improve work queue stability #5710
@@ -341,6 +341,65 @@ struct llama_client_slot
             {"t_total", t_prompt_processing + t_token_generation},
         });
     }

+    // context extension via Self-Extend
+    void grp_attn_update_params() {
+        int grpa_i = 0;
+        // copy to local variables
+        int32_t grpa_n = ga_n;
+        int32_t grpa_w = ga_w;
+        int32_t slot_npast = 0;
+        for (int k = 0; k < n_past; ++k)
+        {
+            while (slot_npast >= grpa_i + grpa_w) {
+                const int bd = (grpa_w/grpa_n)*(grpa_n - 1);
+                slot_npast -= bd;
+                grpa_i += grpa_w/grpa_n;
+            }
+            slot_npast++;
+        }
+        n_past_se = slot_npast;
+        ga_i = grpa_i;
+    }
+
+    int32_t grp_attn_calc_npast() {
+        int32_t slot_npast = n_past_se > 0 ? n_past_se : n_past;
+        // copy to local variables
+        int32_t grpa_i = ga_i;
+        int32_t grpa_n = ga_n;
+        int32_t grpa_w = ga_w;
+        while (slot_npast >= grpa_i + grpa_w) {
+            const int bd = (grpa_w/grpa_n)*(grpa_n - 1);
+            slot_npast -= bd;
+            grpa_i += grpa_w/grpa_n;
+        }
+        return slot_npast;
+    }
+
+    void grp_attn_shift(llama_context * ctx, const int32_t n_tokens) {
+        while (n_past_se >= ga_i + ga_w)
+        {
+            const int ib = (ga_n * ga_i) / ga_w;
+            const int bd = (ga_w / ga_n) * (ga_n - 1);
+            const int dd = (ga_w / ga_n) - ib * bd - ga_w;
+
+            LOG_TEE("\n");
+            LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i, n_past_se, ib * bd, ga_i + ib * bd, n_past_se + ib * bd);
+            LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib * bd, ga_i + ib * bd + ga_w, ga_n, (ga_i + ib * bd) / ga_n, (ga_i + ib * bd + ga_w) / ga_n);
+            LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib * bd + ga_w, n_past_se + ib * bd, dd, ga_i + ib * bd + ga_w + dd, n_past_se + ib * bd + dd);
+
+            llama_kv_cache_seq_add(ctx, id, ga_i, n_past_se, ib * bd);
+            llama_kv_cache_seq_div(ctx, id, ga_i + ib * bd, ga_i + ib * bd + ga_w,ga_n);
+            llama_kv_cache_seq_add(ctx, id, ga_i + ib * bd + ga_w,n_past_se + ib * bd, dd);
+
+            n_past_se -= bd;
+
+            ga_i += ga_w / ga_n;
+
+            LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past_se + bd, n_past_se, ga_i);
+        }
+        n_past_se += n_tokens;
+    }
 };
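The three helpers above hoist the Self-Extend bookkeeping that used to be inlined in update_slots() (see the hunks further down). As a reading aid, here is a stand-alone sketch, not part of the PR, that reproduces the same while-loop arithmetic with illustrative values ga_n = 4 and ga_w = 512; the function name and the printed trace are hypothetical:

```cpp
#include <cstdint>
#include <cstdio>

// Stand-alone illustration of the position remapping done by
// grp_attn_update_params()/grp_attn_calc_npast(): whenever the position reaches
// the end of the current group window, it is pulled back by
// bd = (ga_w/ga_n)*(ga_n - 1) and the window start ga_i advances by ga_w/ga_n.
int32_t grouped_position(int32_t n_past, int32_t ga_n, int32_t ga_w) {
    int32_t ga_i = 0;
    int32_t pos  = 0;
    for (int k = 0; k < n_past; ++k) {
        while (pos >= ga_i + ga_w) {
            const int bd = (ga_w / ga_n) * (ga_n - 1);
            pos  -= bd;
            ga_i += ga_w / ga_n;
        }
        pos++;
    }
    return pos;
}

int main() {
    // with ga_n = 4 and ga_w = 512: 512 -> 512, 1024 -> 640, 2048 -> 896, 4096 -> 1408,
    // i.e. at window boundaries the KV position is ga_w + (n_past - ga_w) / ga_n
    const int32_t ga_n = 4, ga_w = 512;
    for (int32_t n_past : {512, 1024, 2048, 4096}) {
        std::printf("n_past = %5d -> grouped position %5d\n", n_past, grouped_position(n_past, ga_n, ga_w));
    }
    return 0;
}
```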

struct llama_metrics {
@@ -1120,13 +1179,23 @@ struct llama_server_context
         return slot.images.size() > 0;
     }

-    void send_error(task_server& task, const std::string &error)
+    void send_error(task_server &task, const std::string &error)
     {
         LOG_TEE("task %i - error: %s\n", task.id, error.c_str());
+        send_error(task.id, task.multitask_id, error);
+    }
+
+    void send_error(llama_client_slot &slot, const std::string &error)
+    {
+        send_error(slot.task_id, slot.multitask_id, error);
+    }
+
+    void send_error(int task_id, int multitask_id, const std::string &error)
+    {
+        LOG_TEE("task %i - error: %s\n", task_id, error.c_str());
         task_result res;
-        res.id = task.id;
-        res.multitask_id = task.multitask_id;
-        res.stop = false;
+        res.id = task_id;
+        res.multitask_id = multitask_id;
+        res.stop = true;
         res.error = true;
         res.result_json = { { "content", error } };
         queue_results.send(res);
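All three overloads now funnel into the id-based variant, which builds the error task_result (with stop switched from false to true, presumably so the waiting handler treats the errored task as final) and pushes it onto the result queue. The following is a simplified, self-contained sketch of the same delegation pattern; the struct definitions are illustrative stand-ins, not the real server types:

```cpp
#include <cstdio>
#include <string>

// Simplified stand-ins; the real task_server / llama_client_slot carry many more fields.
struct task_server       { int id;      int multitask_id; };
struct llama_client_slot { int task_id; int multitask_id; };
struct task_result       { int id; int multitask_id; bool stop; bool error; std::string content; };

struct error_reporter {
    // id-based overload: the single place where the error result is built
    void send_error(int task_id, int multitask_id, const std::string &error) {
        std::printf("task %i - error: %s\n", task_id, error.c_str());
        task_result res{task_id, multitask_id, /*stop=*/true, /*error=*/true, error};
        publish(res);
    }
    // convenience overloads delegate to the id-based one
    void send_error(task_server &task, const std::string &error) {
        send_error(task.id, task.multitask_id, error);
    }
    void send_error(llama_client_slot &slot, const std::string &error) {
        send_error(slot.task_id, slot.multitask_id, error);
    }
    // stands in for queue_results.send(res)
    void publish(const task_result &res) {
        std::printf("published result for task %i (stop=%d, error=%d)\n", res.id, res.stop, res.error);
    }
};

int main() {
    error_reporter reporter;
    llama_client_slot slot{/*task_id=*/7, /*multitask_id=*/-1};
    reporter.send_error(slot, "failed processing images");
    return 0;
}
```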
@@ -1593,7 +1662,9 @@ struct llama_server_context
        queue_results.send(result);
    }

-   bool update_slots() {
+   void run_slots() {
+       bool has_next_response = false; // whether to schedule next slot run, to generate next token
+
        if (system_need_update)
        {
            LOG_INFO("updating system prompt", {});
@@ -1609,15 +1680,9 @@ struct llama_server_context
                LOG_INFO("all slots are idle and system prompt is empty, clear the KV cache", {});
                kv_cache_clear();
            }
-           return true;
+           return;
        }

-       LOG_VERBOSE("posting NEXT_RESPONSE", {});
-       task_server task;
-       task.type = TASK_TYPE_NEXT_RESPONSE;
-       task.target_id = -1;
-       queue_tasks.post(task);
-
        for (llama_client_slot &slot : slots)
        {
            if (slot.ga_n == 1)
@@ -1815,21 +1880,8 @@ struct llama_server_context

                if (slot.ga_n != 1)
                {
-                   int ga_i = 0;
-                   int32_t ga_n = slot.ga_n;
-                   int32_t ga_w = slot.ga_w;
-                   int32_t slot_npast = 0;
-                   for (int k = 0; k < slot.n_past; ++k)
-                   {
-                       while (slot_npast >= ga_i + ga_w) {
-                           const int bd = (ga_w/ga_n)*(ga_n - 1);
-                           slot_npast -= bd;
-                           ga_i += ga_w/ga_n;
-                       }
-                       slot_npast++;
-                   }
-                   slot.n_past_se = slot_npast;
-                   slot.ga_i = ga_i;
+                   // context extension via Self-Extend
+                   slot.grp_attn_update_params();
                }

                LOG_INFO("slot progression", {
@@ -1875,22 +1927,16 @@ struct llama_server_context
                    // process the prefix of first image
                    std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;

-                   int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-
-                   int32_t ga_i = slot.ga_i;
-                   int32_t ga_n = slot.ga_n;
-                   int32_t ga_w = slot.ga_w;
+                   int32_t slot_npast = slot.n_past;

                    for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
                    {
                        if (slot.ga_n != 1)
                        {
-                           while (slot_npast >= ga_i + ga_w) {
-                               const int bd = (ga_w/ga_n)*(ga_n - 1);
-                               slot_npast -= bd;
-                               ga_i += ga_w/ga_n;
-                           }
+                           // context extension via Self-Extend
+                           slot_npast = slot.grp_attn_calc_npast();
                        }

                        llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
                        slot_npast++;
                    }
@@ -1901,10 +1947,8 @@ struct llama_server_context
                            "slot_id", slot.id,
                            "task_id", slot.task_id,
                        });
-                       // FIXME @phymbert: to be properly tested
-                       // early returning without changing the slot state will block the slot for ever
-                       // no one at the moment is checking the return value
-                       return false;
+                       send_error(slot, "failed processing images");
+                       continue;
                    }

                    // extract the logits only for the last token
@@ -1922,9 +1966,9 @@ struct llama_server_context
        if (batch.n_tokens == 0)
        {
            all_slots_are_idle = true;
-           return true;
        }

+       // loop of n_batch
        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
        {
            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
@@ -1934,28 +1978,9 @@ struct llama_server_context
            if (slot.ga_n != 1)
            {
                // context extension via Self-Extend
-               while (slot.n_past_se >= slot.ga_i + slot.ga_w)
-               {
-                   const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
-                   const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
-                   const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
-
-                   LOG_TEE("\n");
-                   LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
-                   LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
-                   LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
-
-                   llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
-                   llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
-                   llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
-
-                   slot.n_past_se -= bd;
-
-                   slot.ga_i += slot.ga_w / slot.ga_n;
-
-                   LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
-               }
-               slot.n_past_se += n_tokens;
+               // TODO @ngxson: What happen if we're retrying with smaller n_batch?
+               // By the second time we retry, "grp_attn_shift" has already been called
+               slot.grp_attn_shift(ctx, n_tokens);
Review comments on this change:

Comment: @ggerganov I noticed a potential bug here where [...] (The group attention mechanism is still too complex for me to really understand; I'm not sure what I'm doing here is correct or not.)

Comment: I'll try to reimplement the self-extend logic in the following days. Even if there is a bug here, we'll fix it in the new implementation, so don't worry for now.
Btw, it would be very useful to add a passkey test that works with the server with extended context. This is the command that we run using the passkey example:
make -j && ./passkey ./models/llama-7b/ggml-model-f16.gguf 250 4 50
This generates a prompt of about ~6k tokens and puts a number (the "pass key") at the start. It uses self-extend with a factor of 4, so that even a 2k model like LLaMA v1 will be able to recall it. The test would be too heavy for the GH CI, so it should only run locally. Probably a simple [...]
cc @phymbert

Comment: Thanks for the confirmation. I'll leave my TODO here so that I can look into it in the future.

Comment: FYI, I also removed some changes to make this PR smaller, because initially I wanted this PR to be more about "fixing bugs" than "refactoring".
            }
        }

@@ -1979,22 +2004,29 @@ struct llama_server_context
                {
                    // if you get here, it means the KV cache is full - try increasing it via the context size
                    LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
-                   return false;
+                   for (auto & slot : slots)
+                   {
+                       send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
+                       slot.release();
+                   }
+                   has_next_response = false;
+                   break; // break loop of n_batch
                }

                LOG_TEE("%s : failed to find free space in the KV cache, retrying with smaller n_batch = %d\n", __func__, n_batch / 2);

                // retry with half the batch size to try to find a free slot in the KV cache
                n_batch /= 2;
                i -= n_batch;
-               continue;
+               continue; // continue loop of n_batch
            }

+           // loop of slots
            for (auto & slot : slots)
            {
                if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens))
                {
-                   continue;
+                   continue; // continue loop of slots
                }

                // prompt evaluated for embedding
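The decode-failure path now has two distinct outcomes: an unrecoverable failure reports an error to every slot and breaks out of the batch loop, while a recoverable one halves n_batch and retries the same window (the `i -= n_batch; continue;` pair cancels the upcoming `i += n_batch`, so the window start is unchanged). Below is a minimal, self-contained sketch of that retry pattern; try_decode and its return codes are toy stand-ins, not llama_decode:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Toy stand-in for llama_decode: pretend the KV cache only has room for 8 tokens at once.
// Return codes: 0 = ok, 1 = no room (retryable), negative = fatal.
static int try_decode(const std::vector<int> &tokens, int32_t offset, int32_t count) {
    (void) tokens; (void) offset;
    return count > 8 ? 1 : 0;
}

// Process `tokens` in windows of at most n_batch, shrinking the window on a retryable failure.
static bool decode_all(const std::vector<int> &tokens, int32_t n_batch) {
    for (int32_t i = 0; i < (int32_t) tokens.size(); i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, (int32_t) tokens.size() - i);
        const int ret = try_decode(tokens, i, n_tokens);
        if (ret != 0) {
            if (n_batch == 1 || ret < 0) {
                // unrecoverable: report and stop (the PR sends an error to every slot here)
                std::fprintf(stderr, "decode failed, n_batch = %d, ret = %d\n", n_batch, ret);
                return false;
            }
            // recoverable: retry the same window with half the batch size
            n_batch /= 2;
            i -= n_batch;   // compensates the upcoming i += n_batch
            continue;
        }
    }
    return true;
}

int main() {
    std::vector<int> tokens(20, 0);                  // 20 dummy tokens
    const bool ok = decode_all(tokens, /*n_batch=*/32);
    std::printf("decode_all %s\n", ok ? "succeeded" : "failed");
    return 0;
}
```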
@@ -2003,7 +2035,7 @@ struct llama_server_context
                    send_embedding(slot);
                    slot.release();
                    slot.i_batch = -1;
-                   continue;
+                   continue; // continue loop of slots
                }

                completion_token_output result;
@@ -2042,16 +2074,25 @@ struct llama_server_context
                    metrics.on_prediction(slot);
                }

+               // if slot is not yet finish its work, we schedule next run
+               if (slot.has_next_token)
+               {
+                   has_next_response = true;
+               }
+
                slot.i_batch = -1;
            }
        }

-       LOG_VERBOSE("slots updated", {});
-       return true;
-   }
+       if (has_next_response) {
+           LOG_VERBOSE("schedule next slot run", {});
+           task_server task;
+           task.type = TASK_TYPE_NEXT_RESPONSE;
+           task.target_id = -1;
+           queue_tasks.post(task);
+       }

-   void run_on_all_tasks_finished() {
-       update_slots();
        LOG_VERBOSE("slots run completed", {});
    }
};
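This is the heart of the work-queue change: instead of posting TASK_TYPE_NEXT_RESPONSE unconditionally at the top of every pass (the block removed in the @@ -1609 hunk), run_slots() posts it once at the end, and only when at least one slot still has a token to generate, so no follow-up task is queued when there is nothing left to do. A toy sketch of that self-scheduling loop follows; the queue and server types are simplified stand-ins, not the real llama_server_queue:

```cpp
#include <cstdio>
#include <deque>

enum task_type { TASK_TYPE_COMPLETION, TASK_TYPE_NEXT_RESPONSE };
struct task_server { task_type type; int target_id = -1; };

struct toy_queue {
    std::deque<task_server> tasks;
    void post(const task_server &t) { tasks.push_back(t); }
};

struct toy_server {
    toy_queue queue_tasks;
    int tokens_left = 3;               // pretend one slot has 3 tokens to generate

    // one pass over the slots; reschedules itself only while there is work left
    void run_slots() {
        bool has_next_response = false;
        if (tokens_left > 0) {
            std::printf("generated a token, %d left\n", --tokens_left);
            has_next_response = tokens_left > 0;
        }
        if (has_next_response) {
            task_server task;
            task.type = TASK_TYPE_NEXT_RESPONSE;
            queue_tasks.post(task);    // schedule the next run
        }
    }
};

int main() {
    toy_server server;
    // seed one completion-style task, then let run_slots() reschedule itself until idle
    server.queue_tasks.post({TASK_TYPE_COMPLETION});
    while (!server.queue_tasks.tasks.empty()) {
        server.queue_tasks.tasks.pop_front();
        server.run_slots();
    }
    return 0;
}
```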
@@ -3494,7 +3535,7 @@ int main(int argc, char **argv)
            bool running = true;
            while (running)
            {
-               running = llama.update_slots();
+               running = llama.run_slots();
            }
        }*/
    //);
@@ -3516,8 +3557,8 @@ int main(int argc, char **argv)
        &llama_server_context::process_single_task, &llama, std::placeholders::_1));
    llama.queue_tasks.on_finish_multitask(std::bind(
        &llama_server_context::on_finish_multitask, &llama, std::placeholders::_1));
-   llama.queue_tasks.on_all_tasks_finished(std::bind(
-       &llama_server_context::run_on_all_tasks_finished, &llama));
+   llama.queue_tasks.on_run_slots(std::bind(
+       &llama_server_context::run_slots, &llama));
    llama.queue_results.on_multitask_update(std::bind(
        &llama_server_queue::update_multitask,
        &llama.queue_tasks,
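The queue hook is renamed from on_all_tasks_finished to on_run_slots, and main() binds llama_server_context::run_slots to it. A small self-contained sketch of the same std::function / std::bind wiring, with toy types in place of the real queue and context:

```cpp
#include <cstdio>
#include <functional>

// Toy queue that invokes a user-supplied callback when it wants the slots to run.
struct toy_queue {
    std::function<void(void)> callback_run_slots;
    void on_run_slots(std::function<void(void)> cb) { callback_run_slots = std::move(cb); }
    void start_loop() {
        // real code would drain its task list here; we just fire the callback once
        if (callback_run_slots) { callback_run_slots(); }
    }
};

struct toy_server_context {
    void run_slots() { std::printf("running slots\n"); }
};

int main() {
    toy_server_context llama;
    toy_queue queue_tasks;
    // same wiring as in main(): bind the member function to the queue hook
    queue_tasks.on_run_slots(std::bind(&toy_server_context::run_slots, &llama));
    queue_tasks.start_loop();
    return 0;
}
```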
👍
Will this change be backported somewhere else?