speculative : fix KV cache management
ggerganov committed Sep 18, 2023
1 parent 7c1bdd0 · commit 1f17ea6
Showing 1 changed file with 4 additions and 0 deletions.
examples/speculative/speculative.cpp (4 additions & 0 deletions)
@@ -172,6 +172,7 @@ int main(int argc, char ** argv) {
                LOG("out of drafted tokens\n");
            }

+            llama_kv_cache_rm_seq(ctx_dft, 0, n_past_dft, n_ctx);
            llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
            ++n_past_dft;

@@ -217,6 +218,7 @@ int main(int argc, char ** argv) {

        // sample n_draft tokens from the draft model using greedy decoding
        int n_past_cur = n_past_dft;
+
        for (int i = 0; i < n_draft; ++i) {
            float * logits = llama_get_logits(ctx_dft);

@@ -256,6 +258,7 @@ int main(int argc, char ** argv) {
            }

            // evaluate the drafted token on the draft model
+            llama_kv_cache_rm_seq(ctx_dft, 0, n_past_cur, n_ctx);
            llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
            ++n_past_cur;

@@ -265,6 +268,7 @@ int main(int argc, char ** argv) {
        }

        // evaluate the target model on the drafted tokens
+        llama_kv_cache_rm_seq(ctx_tgt, 0, n_past_tgt, n_ctx);
        llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
        ++n_past_tgt;
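All three llama_kv_cache_rm_seq additions apply the same pattern: before decoding at a position p, evict sequence 0's KV cache cells in the range [p, n_ctx), so that cells written by an earlier, longer draft that was subsequently rejected cannot shadow the tokens being evaluated now. Below is a minimal sketch of that pattern, built only from calls that appear in the diff (llama_kv_cache_rm_seq, the three-argument llama_decode, and llama_batch_get_one, as of this commit's API); the helper function itself is illustrative and not code from speculative.cpp:

    #include "common.h" // gpt_params, as used by the llama.cpp examples
    #include "llama.h"

    // Hypothetical helper (not in the file): re-evaluate one token at
    // position n_past on a context, clearing stale KV cells first.
    static void decode_one(llama_context * ctx, llama_token id, int & n_past,
                           const int n_ctx, const gpt_params & params) {
        // drop sequence 0's cells in [n_past, n_ctx): they may still hold
        // tokens from a draft that the target model did not accept
        llama_kv_cache_rm_seq(ctx, 0, n_past, n_ctx);

        // decode the single token at position n_past (API as of this commit)
        llama_decode(ctx, llama_batch_get_one(&id, 1, n_past, 0), params.n_threads);

        ++n_past;
    }

Without the eviction, cells at positions >= n_past left over from a longer, rejected draft would remain in the cache and be attended to on subsequent decodes, corrupting generation for the accepted sequence.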
