Commit
Cache: add new flag to distinguish models that Cache but not static cache (#30800)

* jamba cache

* new flag

* generate exception
gante authored May 16, 2024
1 parent 17cc71e commit 9d889f8
Showing 19 changed files with 23 additions and 3 deletions.
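
In short: `generate()` now refuses `cache_implementation="static"` for models that accept a `Cache` object but have no `StaticCache` support, based on a new `_supports_static_cache` class flag. A minimal usage sketch of the resulting behaviour, assuming a causal LM whose class sets `_supports_cache_class = True` but leaves `_supports_static_cache = False` (as Jamba does after this commit); the checkpoint id below is a placeholder, not a real model:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "org/dynamic-cache-only-model"  # placeholder checkpoint id, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
inputs = tokenizer("Hello", return_tensors="pt")

# Works: the model supports `Cache` instances (a dynamic cache) as `past_key_values`.
model.generate(**inputs, max_new_tokens=8)

# After this commit, this raises the new ValueError ("This model does not support
# `cache_implementation='static'`. ...") because the class does not set `_supports_static_cache = True`.
model.generate(**inputs, max_new_tokens=8, cache_implementation="static")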
5 changes: 5 additions & 0 deletions src/transformers/generation/utils.py
@@ -1616,6 +1616,11 @@ def generate(
                 "issue: https://github.com/huggingface/transformers/issues/28981."
             )
         if generation_config.cache_implementation == "static":
+            if not self._supports_static_cache:
+                raise ValueError(
+                    "This model does not support `cache_implementation='static'`. Please check the following "
+                    "issue: https://github.com/huggingface/transformers/issues/28981"
+                )
             model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length)

         self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
3 changes: 2 additions & 1 deletion src/transformers/modeling_utils.py
@@ -1280,8 +1280,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
     # SDPA support
     _supports_sdpa = False

-    # Has support for a `Cache` instance as `past_key_values`
+    # Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`?
     _supports_cache_class = False
+    _supports_static_cache = False

     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
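
For reference, a model that can safely use both the dynamic `Cache` class and `StaticCache` would now declare both flags on its pretrained-model class, as the per-model diffs below do. A minimal sketch with a hypothetical `MyModelPreTrainedModel` (illustrative only, not part of this commit):

from transformers import PreTrainedModel

class MyModelPreTrainedModel(PreTrainedModel):  # hypothetical example class
    # Accepts a `Cache` instance as `past_key_values` (dynamic cache).
    _supports_cache_class = True
    # Additionally safe with `StaticCache`, i.e. `cache_implementation="static"` in `generate()`.
    _supports_static_cache = True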
1 change: 1 addition & 0 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -720,6 +720,7 @@ class CoherePreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
+    _supports_static_cache = True

     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 1 addition & 0 deletions src/transformers/models/dbrx/modeling_dbrx.py
@@ -938,6 +938,7 @@ class DbrxPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
+    _supports_static_cache = True

     def _init_weights(self, module: nn.Module):
         std = self.config.initializer_range
1 change: 1 addition & 0 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -703,6 +703,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
+    _supports_static_cache = True

     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 1 addition & 0 deletions src/transformers/models/idefics2/modeling_idefics2.py
@@ -1341,6 +1341,7 @@ class Idefics2PreTrainedModel(PreTrainedModel):
     _no_split_modules = ["Idefics2VisionAttention", "Idefics2MLP", "Idefics2PerceiverLayer", "Idefics2DecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True

     def _init_weights(self, module):
         # important: this ported version of Idefics2 isn't meant for training from scratch - only
1 change: 1 addition & 0 deletions src/transformers/models/jamba/modeling_jamba.py
@@ -1261,6 +1261,7 @@ class JambaPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True  # Note: only supports HybridMambaAttentionDynamicCache

     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 1 addition & 0 deletions src/transformers/models/llama/modeling_llama.py
@@ -799,6 +799,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
+    _supports_static_cache = True

     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 1 addition & 0 deletions src/transformers/models/mistral/modeling_mistral.py
@@ -810,6 +810,7 @@ class MistralPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True

     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 1 addition & 0 deletions src/transformers/models/mixtral/modeling_mixtral.py
@@ -989,6 +989,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True

     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 1 addition & 0 deletions src/transformers/models/olmo/modeling_olmo.py
@@ -776,6 +776,7 @@ class OlmoPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
+    _supports_static_cache = True

     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 1 addition & 0 deletions src/transformers/models/persimmon/modeling_persimmon.py
@@ -457,6 +457,7 @@ class PersimmonPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["PersimmonDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
+    _supports_cache_class = True

     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 1 addition & 0 deletions src/transformers/models/phi/modeling_phi.py
@@ -825,6 +825,7 @@ class PhiPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True

     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 1 addition & 0 deletions src/transformers/models/phi3/modeling_phi3.py
@@ -921,6 +921,7 @@ class Phi3PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = False
+    _supports_cache_class = True

     _version = "0.0.5"

1 change: 1 addition & 0 deletions src/transformers/models/qwen2/modeling_qwen2.py
@@ -821,6 +821,7 @@ class Qwen2PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True

     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 1 addition & 0 deletions src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -975,6 +975,7 @@ class Qwen2MoePreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True

     def _init_weights(self, module):
         std = self.config.initializer_range
1 change: 0 additions & 1 deletion src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -541,7 +541,6 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = ["cache"]
     _supports_flash_attn_2 = False
     _supports_sdpa = False  # we can't compare with eager for now
-    _supports_cache_class = True

     def _init_weights(self, module):
         std = math.sqrt(self.config.w_init_variance_scale / self.config.conv1d_width)
1 change: 1 addition & 0 deletions src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -799,6 +799,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True

     def _init_weights(self, module):
         std = self.config.initializer_range
2 changes: 1 addition & 1 deletion tests/test_modeling_common.py
@@ -4365,7 +4365,7 @@ def test_custom_4d_attention_mask(self):
             self.skipTest("Model architecture has no generative classes, and thus not necessarily supporting 4D masks")

         for model_class in self.all_generative_model_classes:
-            if not model_class._supports_cache_class:
+            if not model_class._supports_static_cache:
                 self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
             config, _ = self.model_tester.prepare_config_and_inputs_for_common()
             model = model_class(config).to(device=torch_device, dtype=torch.float32)
