Implement KV cache for transformers models #150
KV caching is a common optimization trick to speed up inference with the transformer architecture. Indeed, a given token only interacts with previous tokens in the attention layer, so we can cache the inputs of the attention blocks (the keys and values) for all previous tokens and reuse them directly when running inference on a new token. It is important to get this right as it can lead to a substantial performance increase. See for instance this comment.
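To illustrate the mechanism, here is a minimal sketch (not outlines code) of greedy decoding with a Hugging Face causal LM, where `past_key_values` is carried across steps so that only the newest token is fed to the model:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
past_key_values = None

for _ in range(10):
    # Once we have a cache we only feed the most recent token; the keys and
    # values of all previous tokens come from `past_key_values`.
    current = input_ids if past_key_values is None else input_ids[:, -1:]
    outputs = model(current, past_key_values=past_key_values, use_cache=True)
    past_key_values = outputs.past_key_values
    next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    input_ids = torch.cat([input_ids, next_token], dim=-1)

print(tokenizer.decode(input_ids[0]))
```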
First we need to build and persist this cache when sampling a new sequence. A linear cache works fine when generating a single sequence, but we can build something more efficient when sampling different sequences. I was originally thinking about building a trie that we query each time the model is called, but we should also take a close look at PagedAttention.
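As a rough sketch of the trie idea (the names and structure below are assumptions, not a settled design), each node could be keyed on a token id and store the KV cache computed up to that prefix, so a new prompt can reuse the cache of its longest cached prefix:

```python
from typing import Any, Dict, List, Optional, Tuple


class KVCacheTrie:
    """Hypothetical prefix trie over token ids; each node may hold a KV cache."""

    def __init__(self):
        self.children: Dict[int, "KVCacheTrie"] = {}
        self.cache: Optional[Any] = None  # e.g. `past_key_values` for this prefix

    def insert(self, token_ids: List[int], cache: Any) -> None:
        node = self
        for token_id in token_ids:
            node = node.children.setdefault(token_id, KVCacheTrie())
        node.cache = cache

    def longest_prefix(self, token_ids: List[int]) -> Tuple[int, Optional[Any]]:
        """Return (length of the longest cached prefix, its KV cache)."""
        node, best_len, best_cache = self, 0, None
        for i, token_id in enumerate(token_ids):
            if token_id not in node.children:
                break
            node = node.children[token_id]
            if node.cache is not None:
                best_len, best_cache = i + 1, node.cache
        return best_len, best_cache
```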
We also want to persist the cache between generation sequences, especially for infilling workflows where we use the previous completion as a prompt. A first approach would be to cache all the text a model has ever been prompted with (maybe using paged attention), but this may quickly fill the memory.
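One way to keep such a persistent cache bounded would be simple LRU eviction; a hedged sketch with hypothetical names, keyed on prompt token ids:

```python
from collections import OrderedDict
from typing import Any, Optional, Tuple


class BoundedKVCache:
    """Hypothetical cross-generation store with LRU eviction."""

    def __init__(self, max_entries: int = 32):
        self.max_entries = max_entries
        self._entries: "OrderedDict[Tuple[int, ...], Any]" = OrderedDict()

    def get(self, token_ids: Tuple[int, ...]) -> Optional[Any]:
        cache = self._entries.get(token_ids)
        if cache is not None:
            self._entries.move_to_end(token_ids)  # mark as recently used
        return cache

    def put(self, token_ids: Tuple[int, ...], cache: Any) -> None:
        self._entries[token_ids] = cache
        self._entries.move_to_end(token_ids)
        if len(self._entries) > self.max_entries:
            self._entries.popitem(last=False)  # evict the least recently used
```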
@thomasahle suggested letting users handle the caching. To do so we could make `Sequence` instances return a state that contains more than the completion, or a `(completion, extra)` tuple by default. The API could look like:
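A possible shape for this first option, sketched directly on top of a Hugging Face model; the `generate` helper and the contents of `extra` are assumptions, not the actual outlines API:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")


def generate(prompt, extra=None, max_tokens=10):
    """Return (completion, extra); `extra` holds the token ids and the KV cache."""
    new_ids = tokenizer(prompt, return_tensors="pt").input_ids
    if extra is None:
        input_ids, past = new_ids, None
    else:
        cached_ids, past = extra
        input_ids = torch.cat([cached_ids, new_ids], dim=-1)
    to_feed = new_ids  # with a cache, only the new prompt tokens need a forward pass
    for _ in range(max_tokens):
        outputs = model(to_feed, past_key_values=past, use_cache=True)
        past = outputs.past_key_values
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        to_feed = next_token
    return tokenizer.decode(input_ids[0]), (input_ids, past)


completion, extra = generate("The capital of France is")
# Reuse the cache: the follow-up prompt is appended to the cached context.
completion, extra = generate(" The capital of Germany is", extra=extra)
```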
An alternative is to pass a `state` to `Sequence` instances which contains both the completion, the KV cache, and potentially other information. This abstracts away KV cache management for infilling workflows:
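A comparable sketch for this second option, again with hypothetical names (`GenerationState`, `complete`): the caller only threads an opaque state through successive calls, and the KV cache lives inside it:

```python
from dataclasses import dataclass
from typing import Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")


@dataclass
class GenerationState:
    completion: str
    token_ids: torch.Tensor
    past_key_values: Any


def complete(prompt, state=None, max_tokens=10):
    new_ids = tokenizer(prompt, return_tensors="pt").input_ids
    if state is None:
        input_ids, past = new_ids, None
    else:
        # The previous completion (and its KV cache) becomes the context.
        input_ids = torch.cat([state.token_ids, new_ids], dim=-1)
        past = state.past_key_values
    to_feed = new_ids  # only the new prompt tokens need a forward pass
    for _ in range(max_tokens):
        outputs = model(to_feed, past_key_values=past, use_cache=True)
        past = outputs.past_key_values
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        to_feed = next_token
    return GenerationState(tokenizer.decode(input_ids[0]), input_ids, past)


state = complete("def add(a, b):")
# Infilling-style follow-up: the new prompt is appended and the cache reused.
state = complete("\n\ndef multiply(a, b):", state=state)
```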
According to @thomasahle, allowing users to pass the KV cache manually would also make it possible to:
Questions
Shouldn't the extra KV cache be passed to the model rather than the sequence generation process? I imagine this will only work with specific model implementations?