
Implement KV cache for transformers models #150

Closed
rlouf opened this issue Jun 21, 2023 · 1 comment
Labels: enhancement, transformers (Linked to the `transformers` integration)

Comments

rlouf (Member) commented Jun 21, 2023

KV caching is a common optimization trick to speed up inference with the transformer architecture. A given token only interacts with previous tokens in the attention layer, so we can cache the inputs of the attention blocks (the keys and values) for all previous tokens and pass them directly when running inference on a new token. It is important to get this right, as it can lead to a substantial performance increase. See for instance this comment.
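
For reference, this is roughly what the mechanism looks like when calling the transformers API directly (a minimal sketch, not the interface outlines would expose): the first forward pass returns the keys/values for the whole prompt, and subsequent passes only need to process the newest token.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("A prompt", return_tensors="pt").input_ids

# First forward pass: compute and keep the keys/values for the whole prompt.
outputs = model(input_ids, use_cache=True)
past_key_values = outputs.past_key_values

# Subsequent passes only feed the newest token; the cached keys/values stand
# in for all previous positions.
next_token = outputs.logits[:, -1].argmax(dim=-1, keepdim=True)
outputs = model(next_token, past_key_values=past_key_values, use_cache=True)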

First we need to build and persist this cache when sampling a new sequence. A linear cache works fine when generating a single sequence, but we can build something more efficient when sampling several sequences. I was originally thinking about building a trie that we look up each time the model is queried, but we should also take a close look at PagedAttention.
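
To make the trie idea concrete, here is a rough sketch (KVTrieNode is a hypothetical class, not an existing outlines data structure): nodes are keyed by token id and store the keys/values computed up to that point, so sequences that share a prefix also share its cached entries.

class KVTrieNode:
    """Hypothetical trie node for sharing cached keys/values across sequences."""

    def __init__(self):
        self.children = {}  # token id -> KVTrieNode
        self.kv = None      # keys/values for the prefix ending at this token

    def insert(self, token_ids, kv_per_position):
        node = self
        for token_id, kv in zip(token_ids, kv_per_position):
            node = node.children.setdefault(token_id, KVTrieNode())
            node.kv = kv

    def longest_cached_prefix(self, token_ids):
        """Return the cached entries for the longest prefix we have already seen."""
        node, cached = self, []
        for token_id in token_ids:
            if token_id not in node.children:
                break
            node = node.children[token_id]
            cached.append(node.kv)
        return cached

PagedAttention takes a different route to the same problem: the cache is split into fixed-size blocks that several sequences can reference, which avoids duplicating shared prefixes without a per-token trie.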

We also want to persist the cache between generation sequences, especially for infilling workflows where we use the previous completion as a prompt. A first approach would be to cache all the text a model has ever been prompted with (maybe using paged attention), but this may quickly fill up memory.

@thomasahle suggested letting users handle the caching. To do so we could make Sequence instances return a state that contains more than the completion, or a tuple (completion, extra) by default. The API could look like:

import outlines.models as models
import outlines.text as text

model = models.transformers("gpt2")
prompt = "A prompt"

# The first call returns the completion together with extra state, e.g. the KV cache.
completion, extra = text.completion(model)(prompt)

# The cached keys/values are then passed explicitly to the next call.
_ = text.completion(model, kv_cache=extra.kv_cache)(completion)

An alternative is to pass a state to Sequence instances that contains the completion, the KV cache and potentially other information. This abstracts away KV cache management for infilling workflows:

import outlines.models as models
import outlines.text as text

model = models.transformers("gpt2")

state = text.completion(model)("A prompt")
state = text.completion(model)(state)

print(state.completion)
# that has been completed

print(state)
# A prompt that has been completed

print(state.kv_cache)
# ...
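
A minimal sketch of what such a state object could look like (GenerationState and its fields are hypothetical, not the outlines API): it prints as the full text but still carries the KV cache for the next call.

from dataclasses import dataclass, field
from typing import Any


@dataclass
class GenerationState:
    """Hypothetical return value bundling the generated text with the KV cache."""

    prompt: str
    completion: str
    kv_cache: Any = field(default=None, repr=False)

    def __str__(self) -> str:
        # Printing the state yields the prompt followed by its completion.
        return self.prompt + self.completion


state = GenerationState("A prompt", " that has been completed")
print(state.completion)  # " that has been completed"
print(state)             # "A prompt that has been completed"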

According to @thomasahle, allowing users to pass the KV cache manually would also make it possible to:

  • Pass a kv_cache that has been "learned" by fine-tuning rather than representing an actual prefix (see the sketch below).
  • Give the model access to a vector DB of "external" KV pairs.
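
For the first point, the cache does not have to come from real text; the tensors only need the shapes the model expects. A loose illustration with GPT-2 shapes and the legacy past_key_values tuple format (the random tensors stand in for what would actually be learned):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
config = model.config

# A "virtual prefix" of 8 positions: one (key, value) pair per layer, with the
# shapes the model would have produced for 8 real tokens.
prefix_length = 8
head_dim = config.n_embd // config.n_head
past_key_values = tuple(
    (
        torch.randn(1, config.n_head, prefix_length, head_dim),
        torch.randn(1, config.n_head, prefix_length, head_dim),
    )
    for _ in range(config.n_layer)
)

input_ids = tokenizer("A prompt", return_tensors="pt").input_ids
outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)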

Questions

Shouldn't the extra KV cache be passed to the model rather than to the sequence generation process? I imagine this will only work with specific model implementations.

rlouf (Member, Author) commented Jul 13, 2023

Similar to what was suggested above, Explosion's curated-transformers takes a manual approach to caching: the cache can be returned at each step and passed on to the next sequence generation.

dottxt-ai locked and limited conversation to collaborators Jul 14, 2023
rlouf converted this issue into discussion #190 Jul 14, 2023
