Add torch.compile Support For Mamba #31247
Conversation
It seems that the mamba cache is not compatible with the current cache design used in
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Looking good!
I think for this we want a general solution that would work for hybrid caches as well (like jamba / mamba2 / zamba / etc.).
Here it's possible to init the cache before going into the forward if you extend NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache} with a "mamba" entry? It's not too bad 😅
Otherwise we could redefine the StaticCache for mamba to be the MambaCache class.
cc @zucchini-nlp and @gante 😉
nice 🔥
def is_initialized(self, layer_idx):
    return self.is_cache_initialized[layer_idx]
This can be checked with cache_positions instead, no?
I think this is fine, just like we need a flag in whisper. Here cache_positions is not so meaningful because we always know how to update and get the cache even if cache_positions is not passed.
Theoretically yes, but this is adding some complexity, which is not needed in the cache API. Checking the cache positions is more reliable, and is what we want to go with.
- you don't have to reset and set another tensor, which is also a win
Let's just use the cache positions.
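For context, a minimal sketch of that suggestion (inferring initialization from cache_position rather than a per-layer boolean flag); the helper name and the exact condition are illustrative, not the code that was merged:

```python
import torch

# Hypothetical helper: infer whether the Mamba cache has already been filled
# from cache_position instead of tracking a separate is_cache_initialized flag.
def cache_is_initialized(cache_position: torch.LongTensor) -> bool:
    # During prefill, cache_position starts at 0; once decoding starts it is > 0.
    return bool(cache_position[0] > 0)

print(cache_is_initialized(torch.arange(0, 4)))  # prefill step -> False
print(cache_is_initialized(torch.tensor([7])))   # decoding step -> True
```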
Could you share benchmark results?
Let's use cache positions for the check, and be careful of BC!
Otherwise great work!
Almost done! Again, great work! Let's just define a good API for future models, e.g. RecurrentGemma, which will also benefit from this!
src/transformers/generation/utils.py
Outdated
elif generation_config.cache_implementation == "mamba":
    from ..models.mamba.modeling_mamba import MambaCache, MambaConfig

    if not isinstance(self.config, MambaConfig):
        raise ValueError(
            "You can only specify `cache_implementation` to `mamba` if you are using mamba model"
        )

    if hasattr(self, "_cache"):
        assert isinstance(self._cache, MambaCache), "Only `MambaCache` can be used on mamba model"
        need_new_cache = self._cache.conv_states.shape[1] != batch_size
    else:
        need_new_cache = True

    if need_new_cache:
        self._cache = MambaCache(
            config=self.config, batch_size=batch_size, dtype=self.dtype, device=self.device
        )
    else:
        self._cache.reset()
    model_kwargs["cache_params"] = self._cache
The problem with this is that it does not scale with new models. It's not something we want to do at all, TBH.
The simplest is to import the MambaCache and add it to the mapping as "mamba": MambaCache.
needs_new_cache should be specific to the cache class. Maybe that is the best approach: for a new cache class it will be the correct way to say whether or not we reset!
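For illustration, a hedged sketch of the mapping-based approach suggested here; the import paths follow the snippets quoted in this thread, but the exact shape of the final mapping is an assumption, not the merged code:

```python
from transformers.cache_utils import SlidingWindowCache, StaticCache
from transformers.models.mamba.modeling_mamba import MambaCache

# Extending the existing mapping instead of adding a model-specific branch in
# generate(): a "mamba" entry lets the generic cache-setup path instantiate a
# MambaCache just like it does for the static and sliding-window caches.
NEED_SETUP_CACHE_CLASSES_MAPPING = {
    "static": StaticCache,
    "sliding_window": SlidingWindowCache,
    "mamba": MambaCache,
}
```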
    self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
) -> torch.Tensor:
    conv_state = self.conv_states[layer_idx]
    cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
A bool flag should be dynamo-compatible, but I trust you on this one and it's fairly small, so LGTM.
    hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1)  # [batch, intermediate_size, 1] : decoding
else:
    if cache_position.shape[0] == self.conv_kernel_size:
Might be worth adding a comment in the code to explain the trick. I'm more in favor of using cache_position[0] to detect decoding if it works; if not, then a small comment!
cache_params = MambaCache(
    self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
)
cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
Okay. cache_position[0] > 0 breaks the full graph, I guess?
    input_ids = input_ids[:, -1].unsqueeze(-1)
if use_cache:
    # `cache_position` should have been initialized in `generate`
    assert cache_position is not None
Let's raise an error rather than using asserts.
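A minimal sketch of what that change could look like, replacing the assert quoted above; the exception type and message are illustrative, not the merged wording:

```python
if use_cache:
    if cache_position is None:
        # Raise an explicit, user-facing error instead of relying on an assert.
        raise ValueError(
            "`cache_position` should have been initialized in `generate`, but is None. "
            "Pass `cache_position` when calling the model's forward directly."
        )
```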
Looks good to me! cc @gante if you can have a look at the generate changes!
src/transformers/generation/utils.py
Outdated
@@ -1751,7 +1758,8 @@ def generate(
    )

    use_dynamic_cache_by_default = False
    if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
        cache_name = getattr(self, "cache_name", "past_key_values")
Might be better to set this as a class attribute: everything that inherits from Cache will have "past_key_values", and mamba will get "cache_params". WDYT?
Currently MambaCache does not inherit from Cache, because the Cache APIs are only suitable for transformer models with kv states. So you mean making cache_name a class attribute of Cache and MambaCache, with values past_key_values and cache_params respectively?
cache_name = "cache_params"
for mamba cache class, and cache_name = "past_key_values"
for Cache
classes !
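A hedged sketch of that class-attribute idea; the class bodies are stripped down to the attribute being discussed and are illustrative, not the actual transformers definitions:

```python
class Cache:
    # Model kwarg under which transformer-style KV caches are passed around.
    cache_name = "past_key_values"


class MambaCache:
    # Mamba models receive their conv/ssm state cache under a different kwarg.
    cache_name = "cache_params"
```

generate() could then read the attribute from the cache class (or the model) instead of hard-coding the kwarg name.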
I also prefer a more verbose version for now. I spent some time looking for the cache_name variable in this review, which is not a good indicator of readability :D
e.g.
if "mamba" in self.__class__.__name__.lower():
    cache_var_name = "cache_params"
else:
    cache_var_name = "past_key_values"
Ok, I guess it's a way to see if we are using mamba-related models. There is an issue with associating cache_name with Cache: we need to know which cache we are creating in order to know the cache name, which creates a circular issue when we try to check whether users are passing both cache_implementation and a cache instance. Let's go with it for now.
# we initialize the `cache_position` to full size of `conv_states` at prefill stage
# considering padding will be applied when input length is shorter, and truncation
# will be applied when it is longer, so it will be equivalent to always have it match
# the length of `cache_params.conv_states`, which is `config.conv_kernel`
cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device)
I think there might be a more compile-friendly way to do this, but that will be a TODO. https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit#heading=h.ivdr7fmrbeab might have answers, since I do not. LGTM for now!
Yes, it's just a matter of making it use data-independent ops. It could either be a flag or a shape-dependent way of seeing which stage we are in; I think a bool flag in forward would also do the trick, but we introduced cache_position to address this anyway. Another way of thinking about it: we are effectively altering the length of the hidden states by applying padding (positive or negative) before we update the cache, so we need to make sure the cache_position is aligned with the hidden states after padding rather than before padding.
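To make the stage detection concrete, here is a small hedged illustration of how cache_position behaves at prefill versus decoding under the scheme described above (conv_kernel = 4 is an arbitrary example value, not taken from a real config):

```python
import torch

conv_kernel = 4  # example value; in the model this comes from config.conv_kernel

# Prefill: cache_position spans the full conv window regardless of prompt length,
# since the conv states are padded/truncated to conv_kernel anyway.
prefill_position = torch.arange(0, conv_kernel)       # tensor([0, 1, 2, 3])

# Decoding: cache_position keeps advancing with each generated token, so it is
# clamped before indexing the fixed-size conv_states buffer.
decoding_position = torch.tensor([9])
index = decoding_position.clamp(0, conv_kernel - 1)   # tensor([3])

# Shape-based stage detection, as in the snippet quoted earlier in the thread:
is_prefill = prefill_position.shape[0] == conv_kernel  # True
```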
Got it, thanks!
generate changes look (mostly) good to me 🤗
Looks good!
Congrats on the merge! 🔥
torch.compile support for mamba! Closes #31246