Server: Improve work queue stability #5710

Closed
wants to merge 11 commits
examples/server/server.cpp (118 additions, 77 deletions)

@@ -341,6 +341,65 @@ struct llama_client_slot
{"t_total", t_prompt_processing + t_token_generation},
});
}

// context extension via Self-Extend
void grp_attn_update_params() {
int grpa_i = 0;
// copy to local variables
int32_t grpa_n = ga_n;
int32_t grpa_w = ga_w;
int32_t slot_npast = 0;
for (int k = 0; k < n_past; ++k)
{
while (slot_npast >= grpa_i + grpa_w) {
const int bd = (grpa_w/grpa_n)*(grpa_n - 1);
slot_npast -= bd;
grpa_i += grpa_w/grpa_n;
}
slot_npast++;
}
n_past_se = slot_npast;
ga_i = grpa_i;
}

int32_t grp_attn_calc_npast() {
int32_t slot_npast = n_past_se > 0 ? n_past_se : n_past;
// copy to local variables
int32_t grpa_i = ga_i;
int32_t grpa_n = ga_n;
int32_t grpa_w = ga_w;
while (slot_npast >= grpa_i + grpa_w) {
const int bd = (grpa_w/grpa_n)*(grpa_n - 1);
slot_npast -= bd;
grpa_i += grpa_w/grpa_n;
}
return slot_npast;
}

void grp_attn_shift(llama_context * ctx, const int32_t n_tokens) {
while (n_past_se >= ga_i + ga_w)
{
const int ib = (ga_n * ga_i) / ga_w;
const int bd = (ga_w / ga_n) * (ga_n - 1);
const int dd = (ga_w / ga_n) - ib * bd - ga_w;

LOG_TEE("\n");
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i, n_past_se, ib * bd, ga_i + ib * bd, n_past_se + ib * bd);
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib * bd, ga_i + ib * bd + ga_w, ga_n, (ga_i + ib * bd) / ga_n, (ga_i + ib * bd + ga_w) / ga_n);
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib * bd + ga_w, n_past_se + ib * bd, dd, ga_i + ib * bd + ga_w + dd, n_past_se + ib * bd + dd);

llama_kv_cache_seq_add(ctx, id, ga_i, n_past_se, ib * bd);
llama_kv_cache_seq_div(ctx, id, ga_i + ib * bd, ga_i + ib * bd + ga_w, ga_n);
llama_kv_cache_seq_add(ctx, id, ga_i + ib * bd + ga_w, n_past_se + ib * bd, dd);

n_past_se -= bd;

ga_i += ga_w / ga_n;

LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past_se + bd, n_past_se, ga_i);
}
n_past_se += n_tokens;
}
};
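For intuition, here is a self-contained sketch of the position mapping these helpers implement; the values ga_n = 4, ga_w = 512 and the sample positions are illustrative, not taken from this PR.

#include <cstdint>
#include <cstdio>

// Standalone sketch of the grp_attn_calc_npast() mapping above: every full group window of
// width ga_w pulls later positions back by bd = (ga_w/ga_n)*(ga_n - 1), so for large n_past
// the effective KV position grows roughly ga_n times slower than the real one.
int main() {
    const int32_t ga_n = 4;   // group-attention factor (illustrative)
    const int32_t ga_w = 512; // group-attention window (illustrative)

    for (int32_t n_past : {100, 600, 1000, 4000}) {
        int32_t slot_npast = n_past;
        int32_t grpa_i = 0; // ga_i of a fresh slot
        while (slot_npast >= grpa_i + ga_w) {
            const int32_t bd = (ga_w / ga_n) * (ga_n - 1);
            slot_npast -= bd;
            grpa_i += ga_w / ga_n;
        }
        printf("n_past = %4d -> effective position = %4d (ga_i = %4d)\n",
               (int) n_past, (int) slot_npast, (int) grpa_i);
    }
    return 0;
}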

struct llama_metrics {
@@ -1120,13 +1179,23 @@ struct llama_server_context
return slot.images.size() > 0;
}

void send_error(task_server &task, const std::string &error)
{
send_error(task.id, task.multitask_id, error);
}

void send_error(llama_client_slot &slot, const std::string &error)
{
send_error(slot.task_id, slot.multitask_id, error);
}

void send_error(int task_id, int multitask_id, const std::string &error)
{
LOG_TEE("task %i - error: %s\n", task_id, error.c_str());
task_result res;
res.id = task_id;
res.multitask_id = multitask_id;
res.stop = true;
res.error = true;
res.result_json = { { "content", error } };
queue_results.send(res);
@@ -1593,7 +1662,9 @@ struct llama_server_context
queue_results.send(result);
}

void run_slots() {
bool has_next_response = false; // whether to schedule another slot run to generate the next token

if (system_need_update)
{
LOG_INFO("updating system prompt", {});
@@ -1609,15 +1680,9 @@
LOG_INFO("all slots are idle and system prompt is empty, clear the KV cache", {});
kv_cache_clear();
}
return;
}

LOG_VERBOSE("posting NEXT_RESPONSE", {});
task_server task;
task.type = TASK_TYPE_NEXT_RESPONSE;
task.target_id = -1;
queue_tasks.post(task);

for (llama_client_slot &slot : slots)
{
if (slot.ga_n == 1)
@@ -1815,21 +1880,8 @@

if (slot.ga_n != 1)
{
// context extension via Self-Extend
slot.grp_attn_update_params();
}

LOG_INFO("slot progression", {
@@ -1875,22 +1927,16 @@
// process the prefix of first image
std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;

int32_t slot_npast = slot.n_past;

for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
{
if (slot.ga_n != 1)
{
// context extension via Self-Extend
slot_npast = slot.grp_attn_calc_npast();
}

llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
slot_npast++;
}
@@ -1901,10 +1947,8 @@
"slot_id", slot.id,
"task_id", slot.task_id,
});
send_error(slot, "failed processing images");
Collaborator: 👍

Collaborator: Will this change be backported somewhere else?
continue;
}

// extract the logits only for the last token
@@ -1922,9 +1966,9 @@
if (batch.n_tokens == 0)
{
all_slots_are_idle = true;
}

// loop of n_batch
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch)
{
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
@@ -1934,28 +1978,9 @@
if (slot.ga_n != 1)
{
// context extension via Self-Extend
// TODO @ngxson: What happens if we retry with a smaller n_batch?
// By the second attempt, "grp_attn_shift" has already been called
slot.grp_attn_shift(ctx, n_tokens);
Collaborator Author (@ngxson, Feb 25, 2024):

@ggerganov I noticed a potential bug here: llama_kv_cache_seq_shift and llama_kv_cache_seq_div may be called multiple times when we retry llama_decode with a different batch size. Can you please have a look to see if that's the case? Thanks.

(The group attention mechanism is still too complex for me to fully understand, so I'm not sure whether what I'm doing here is correct.)
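To illustrate the concern, a toy model of the retry path (the numbers, the failing first decode, and the simplified loop are assumptions for illustration, not the server's actual control flow): when llama_decode fails and the same token range is retried with a halved n_batch, the Self-Extend shift runs again for that slot, so n_past_se and ga_i drift from what a single successful pass would produce.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const int32_t ga_n = 4, ga_w = 512; // illustrative Self-Extend parameters
    int32_t ga_i = 0, n_past_se = 600;  // made-up slot state entering the batch loop
    int32_t n_batch = 512;
    const int32_t n_tokens_total = 512;

    int attempt = 0;
    for (int32_t i = 0; i < n_tokens_total; i += n_batch) {
        const int32_t n_tokens = std::min(n_batch, n_tokens_total - i);

        // grp_attn_shift(): shift the grouped positions while they exceed the current window
        while (n_past_se >= ga_i + ga_w) {
            const int32_t bd = (ga_w / ga_n) * (ga_n - 1);
            n_past_se -= bd;
            ga_i += ga_w / ga_n;
            printf("shift applied: n_past_se = %d, ga_i = %d\n", (int) n_past_se, (int) ga_i);
        }
        n_past_se += n_tokens;

        // pretend the first decode fails (e.g. no free KV slot) and is retried, as in run_slots()
        if (attempt++ == 0) {
            n_batch /= 2;
            i -= n_batch; // step back and retry with a smaller batch
            continue;     // the shift above runs a second time for the retried tokens
        }
    }
    printf("final: n_past_se = %d, ga_i = %d\n", (int) n_past_se, (int) ga_i);
    return 0;
}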

Owner (@ggerganov):

I'll try to reimplement the self-extend logic in the following days. Even if there is a bug here, we'll fix it in the new implementation, so don't worry for now.

Btw, it would be very useful to add a passkey test that works with the server with extended context. This is the command that we run using the passkey example:

make -j && ./passkey ./models/llama-7b/ggml-model-f16.gguf 250 4 50

This generates a prompt of ~6k tokens and puts a number (the "pass key") at the start. It uses self-extend with a factor of 4, so that even a 2k-context model like LLaMA v1 will be able to recall it.

The test would be too heavy for the GH CI, so it should only run locally. Probably a simple curl command that sends a prompt similar to the example I've shown above. It could even be a multi-user test, so we can verify that self-extend works with more than one prompt in parallel.

cc @phymbert
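As a starting point for such a test, here is a minimal sketch of building a passkey-style prompt to send to the server; the filler sentences, the helper name make_passkey_prompt, the sizes, and the example group-attention settings are illustrative assumptions, not the actual passkey example or an agreed test design.

#include <cstdio>
#include <string>

// Hypothetical helper: build a long "haystack" prompt with a pass key hidden near the start.
// A local test could POST this prompt to the server's /completion endpoint (e.g. via curl)
// and check that the pass key appears in the generated answer.
static std::string make_passkey_prompt(int n_junk, int passkey, int i_pos) {
    const std::string junk =
        "The grass is green. The sky is blue. The sun is yellow. "
        "Here we go. There and back again. "; // filler text (illustrative)

    std::string prompt =
        "There is a pass key hidden inside a lot of irrelevant text. "
        "Find it and memorize it; I will ask you for it later.\n";

    for (int i = 0; i < n_junk; ++i) {
        if (i == i_pos) {
            prompt += "The pass key is " + std::to_string(passkey) +
                      ". Remember it. " + std::to_string(passkey) + " is the pass key. ";
        }
        prompt += junk;
    }
    prompt += "\nWhat is the pass key?";
    return prompt;
}

int main() {
    const int passkey = 42135; // arbitrary
    // ~250 repetitions of the filler gives a prompt in the ballpark of 6k tokens,
    // with the key placed near the start so it falls outside a 2k native context.
    const std::string prompt = make_passkey_prompt(/*n_junk=*/250, passkey, /*i_pos=*/5);
    printf("pass key: %d, prompt length: %zu chars\n", passkey, prompt.size());
    // The server would be started with Self-Extend enabled, e.g. --grp-attn-n 4 --grp-attn-w 512
    // (illustrative values), and queried with this prompt, possibly from several clients at once.
    return 0;
}

A multi-user variant could simply send several such prompts (with different pass keys) in parallel and check each answer, which would exercise Self-Extend across slots as suggested above.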

Collaborator Author (@ngxson):

Thanks for the confirmation. I'll leave my TODO here so that I can look into it in the future.

Collaborator Author (@ngxson, Feb 25, 2024):

FYI, I also removed some changes to make this PR smaller, because initially I wanted this PR to be more about "fixing bugs" than "refactoring".

}
}

@@ -1979,22 +2004,29 @@
{
// if you get here, it means the KV cache is full - try increasing it via the context size
LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
for (auto & slot : slots)
{
send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
slot.release();
}
has_next_response = false;
break; // break loop of n_batch
}

LOG_TEE("%s : failed to find free space in the KV cache, retrying with smaller n_batch = %d\n", __func__, n_batch / 2);

// retry with half the batch size to try to find a free slot in the KV cache
n_batch /= 2;
i -= n_batch;
continue; // continue loop of n_batch
}

// loop of slots
for (auto & slot : slots)
{
if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens))
{
continue; // continue loop of slots
}

// prompt evaluated for embedding
@@ -2003,7 +2035,7 @@
send_embedding(slot);
slot.release();
slot.i_batch = -1;
continue; // continue loop of slots
}

completion_token_output result;
@@ -2042,16 +2074,25 @@
metrics.on_prediction(slot);
}

// if the slot has not finished its work yet, schedule the next run
if (slot.has_next_token)
{
has_next_response = true;
}

slot.i_batch = -1;
}
}

LOG_VERBOSE("slots updated", {});
return true;
}
if (has_next_response) {
LOG_VERBOSE("schedule next slot run", {});
task_server task;
task.type = TASK_TYPE_NEXT_RESPONSE;
task.target_id = -1;
queue_tasks.post(task);
}

void run_on_all_tasks_finished() {
update_slots();
LOG_VERBOSE("slots run completed", {});
}
};

@@ -3494,7 +3535,7 @@ int main(int argc, char **argv)
bool running = true;
while (running)
{
running = llama.run_slots();
}
}*/
//);
@@ -3516,8 +3557,8 @@
&llama_server_context::process_single_task, &llama, std::placeholders::_1));
llama.queue_tasks.on_finish_multitask(std::bind(
&llama_server_context::on_finish_multitask, &llama, std::placeholders::_1));
llama.queue_tasks.on_run_slots(std::bind(
&llama_server_context::run_slots, &llama));
llama.queue_results.on_multitask_update(std::bind(
&llama_server_queue::update_multitask,
&llama.queue_tasks,