Commit

llama : fix edge case finding batch seq_id of split recurrent cell
Without this fix, running the HellaSwag benchmark with small batch sizes
would crash.
compilade committed Jun 1, 2024
1 parent 18d1c14 commit 61200ef
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions llama.cpp
@@ -3879,11 +3879,17 @@ static bool llama_cache_find_slot(
     if (cell.tail_rc == 0) {
         cache.rs.clear_cell(cell);
     } else {
-        // TODO: does this always work correctly
-        // even if there are more than one seq_node in this cell?
+        // Find the seq_id of the first tail of this cell
+        llama_seq_id seq_id = -1;
+        for (llama_rs_seq_node & seq_node : cell.seq_nodes) {
+            if (seq_node.is_tail()) {
+                seq_id = seq_node.seq_id;
+                break;
+            }
+        }
+        GGML_ASSERT(seq_id != -1);

         // Which seq_id of the batch is it?
-        llama_seq_id seq_id = cell.seq_nodes[0].seq_id;
         int32_t nth_seq_id = -1;
         for (int32_t s = 0; (uint32_t) s < n_seqs; ++s) {
             if (seq_id == batch.seq_id[s][0]) {
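
To illustrate the edge case, here is a minimal standalone sketch, not the actual llama.cpp code. The types `seq_node` and `rs_cell`, the helper names, and the convention that a negative `next_cell` marks a tail are all assumptions made for this example; only the overall shape (scan `seq_nodes` for the first tail, assert that one was found, then look it up in the batch) follows the diff above.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for the llama.cpp types touched by the diff.
// Only the fields needed for this illustration are modeled.
using seq_id_t = int32_t;

struct seq_node {
    seq_id_t seq_id;
    int32_t  next_cell; // assumption: a negative value means "no next cell", i.e. a tail
    bool is_tail() const { return next_cell < 0; }
};

struct rs_cell {
    std::vector<seq_node> seq_nodes;
};

// Old lookup: take the first node of the cell, tail or not.
// (The empty-cell guard is added here only to keep the sketch safe.)
seq_id_t first_node_seq_id(const rs_cell & cell) {
    return cell.seq_nodes.empty() ? -1 : cell.seq_nodes[0].seq_id;
}

// New lookup: take the first node that is a tail, or -1 if there is none.
seq_id_t first_tail_seq_id(const rs_cell & cell) {
    for (const seq_node & node : cell.seq_nodes) {
        if (node.is_tail()) {
            return node.seq_id;
        }
    }
    return -1;
}

// Position of the cell's tail sequence within the batch, or -1 if absent.
int32_t batch_index_of_cell(const rs_cell & cell,
                            const std::vector<seq_id_t> & batch_seq_ids) {
    const seq_id_t seq_id = first_tail_seq_id(cell);
    assert(seq_id != -1); // plays the role of GGML_ASSERT(seq_id != -1) in the diff
    for (int32_t s = 0; s < (int32_t) batch_seq_ids.size(); ++s) {
        if (batch_seq_ids[s] == seq_id) {
            return s;
        }
    }
    return -1;
}
```

With a cell holding a non-tail node for sequence 3 followed by a tail node for sequence 7, and a batch that contains only sequence 7, `first_node_seq_id` yields 3 (which is not in the batch) while `first_tail_seq_id` yields 7, so `batch_index_of_cell` finds it at index 0. Presumably this kind of mismatch is what small batch sizes exposed.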