Using abort_callback from ggml to interrupt llama computation #5409

Merged · 4 commits · Mar 2, 2024

llama.cpp (16 additions & 2 deletions)

@@ -1948,6 +1948,9 @@ struct llama_context {
     std::vector<uint8_t> buf_compute_meta;
     ggml_backend_sched_t sched = nullptr;
 
+    ggml_abort_callback abort_callback = nullptr;
+    void * abort_callback_data = nullptr;
+
     // input tensors
     ggml_backend_buffer_t buf_input = nullptr;
     ggml_context * ctx_input = nullptr;
@@ -7847,6 +7850,7 @@ static void llama_graph_compute(
 
     if (lctx.backend_cpu != nullptr) {
         ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
+        ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data);
     }
 
     ggml_backend_sched_graph_compute(lctx.sched, gf);
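
For reference, `ggml_abort_callback` is a plain function-pointer type from ggml.h (its signature is also spelled out in the `llama_set_abort_callback` definition below), and a null callback never aborts. The polling idiom sketched here is illustrative only; the actual ggml internals are not part of this diff:

```cpp
// From ggml.h: the callback type forwarded above.
typedef bool (*ggml_abort_callback)(void * data);

// Sketch of the polling idiom (not the actual ggml source): between graph
// nodes, the CPU backend asks the callback whether computation should stop.
static bool should_abort(ggml_abort_callback cb, void * cb_data) {
    return cb != nullptr && cb(cb_data); // a null callback never aborts
}
```
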
@@ -11644,6 +11648,8 @@ struct llama_context_params llama_context_default_params() {
         /*.embedding           =*/ false,
         /*.offload_kqv         =*/ true,
         /*.do_pooling          =*/ true,
+        /*.abort_callback      =*/ nullptr,
+        /*.abort_callback_data =*/ nullptr,
     };
 
     return result;
@@ -11835,8 +11841,11 @@ struct llama_context * llama_new_context_with_model(
     LLAMA_LOG_INFO("%s: freq_base  = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
 
-    ctx->rng = std::mt19937(params.seed);
-    ctx->logits_all = params.logits_all;
+    ctx->abort_callback      = params.abort_callback;
+    ctx->abort_callback_data = params.abort_callback_data;
+
+    ctx->rng        = std::mt19937(params.seed);
+    ctx->logits_all = params.logits_all;
 
     const ggml_type type_k = params.type_k;
     const ggml_type type_v = params.type_v;
@@ -12809,6 +12818,11 @@ void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) {
     ctx->cparams.n_threads_batch = n_threads_batch;
 }
 
+void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
+    ctx->abort_callback      = abort_callback;
+    ctx->abort_callback_data = abort_callback_data;
+}
+
 struct llama_batch llama_batch_get_one(
     llama_token * tokens,
     int32_t n_tokens,
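
Taken together, llama.cpp now offers two symmetric ways to install the callback: seed it at context creation through `llama_context_params`, or swap it later with `llama_set_abort_callback`. A minimal sketch of the creation-time route, assuming the standard loading calls of this API generation (`llama_load_model_from_file`, `llama_new_context_with_model`); the model path, flag, and callback names are illustrative:

```cpp
#include <atomic>
#include "llama.h"

// Illustrative cancellation flag; a real program would set it from another
// thread, a signal handler, or a UI event.
static std::atomic<bool> g_must_abort{false};

static bool my_abort_cb(void * /*data*/) {
    // Returning true asks the CPU backend to stop the current llama_decode().
    return g_must_abort.load();
}

int main() {
    llama_model * model = llama_load_model_from_file("model.gguf", // hypothetical path
                                                     llama_model_default_params());

    llama_context_params cparams = llama_context_default_params();
    cparams.abort_callback      = my_abort_cb; // new field from this PR
    cparams.abort_callback_data = nullptr;     // handed back to the callback verbatim

    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // ... llama_decode(ctx, batch) now honors g_must_abort on CPU ...

    llama_free(ctx);
    llama_free_model(model);
    return 0;
}
```
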
llama.h (11 additions & 2 deletions)

@@ -252,10 +252,16 @@ extern "C" {
 
         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool mul_mat_q;   // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
-        bool logits_all;  // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
+        bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embedding;   // embedding mode only
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
         bool do_pooling;  // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
+
+        // Abort callback
+        // if it returns true, execution of llama_decode() will be aborted
+        // currently works only with CPU execution
+        ggml_abort_callback abort_callback;
+        void *              abort_callback_data;
     };
 
     // model quantization parameters
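
As the new comment spells out, a `true` return from the callback aborts the in-flight `llama_decode()`, and only the CPU backend honors it. One natural use is a wall-clock budget; a hedged sketch, with the deadline struct and names being illustrative:

```cpp
#include <chrono>

// Illustrative deadline object passed through abort_callback_data.
struct decode_deadline {
    std::chrono::steady_clock::time_point expires_at;
};

static bool deadline_abort_cb(void * data) {
    const auto * d = static_cast<const decode_deadline *>(data);
    // true => abort llama_decode(); false => keep computing.
    return std::chrono::steady_clock::now() >= d->expires_at;
}
```

Wired up by setting `cparams.abort_callback = deadline_abort_cb;` and `cparams.abort_callback_data = &deadline;` before creating the context.
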
@@ -661,7 +667,10 @@ extern "C" {
     // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
     LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
 
-    // Token logits obtained from the last call to llama_eval()
+    // Set abort callback
+    LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
+
+    // Token logits obtained from the last call to llama_decode()
     // The logits for the last token are stored in the last row
     // Logits for which llama_batch.logits[i] == 0 are undefined
     // Rows: n_tokens provided with llama_batch
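
The setter complements the struct field by letting callers change the callback per request instead of fixing it for the lifetime of the context. A sketch of scoping a cancellation flag to a single decode call (the helper name and flag are illustrative; `ctx` and `batch` are assumed to exist):

```cpp
#include <atomic>
#include "llama.h"

static std::atomic<bool> g_cancel{false};

static bool cancel_cb(void * data) {
    return static_cast<std::atomic<bool> *>(data)->load();
}

// Hypothetical helper: run one decode call under a cancellation flag.
static int32_t decode_cancellable(llama_context * ctx, llama_batch batch) {
    g_cancel.store(false);
    llama_set_abort_callback(ctx, cancel_cb, &g_cancel); // new API from this PR
    const int32_t ret = llama_decode(ctx, batch);        // stops early if g_cancel flips
    llama_set_abort_callback(ctx, nullptr, nullptr);     // restore "never abort"
    return ret;
}
```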