diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md
index 1172e32fd0cc..3738a4cae7b2 100644
--- a/docs/source/en/internal/generation_utils.md
+++ b/docs/source/en/internal/generation_utils.md
@@ -398,6 +398,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 
 [[autodoc]] HybridCache
     - update
+    - get_seq_length
     - reset
 
 [[autodoc]] SlidingWindowCache
diff --git a/docs/source/en/model_doc/gemma2.md b/docs/source/en/model_doc/gemma2.md
index 5befa0b1f437..431c4ecd25f2 100644
--- a/docs/source/en/model_doc/gemma2.md
+++ b/docs/source/en/model_doc/gemma2.md
@@ -30,6 +30,12 @@ Tips:
 
 - The original checkpoints can be converted using the conversion script `src/transformers/models/Gemma2/convert_Gemma2_weights_to_hf.py`
 
+
+
+- Gemma2 uses sliding window attention every second layer, which makes it unsuitable for typical KV caching with [`~DynamicCache`] or tuples of tensors. To enable caching in the Gemma2 forward call, you must initialize a [`~HybridCache`] instance and pass it as `past_key_values` to the forward call. Note that you also have to prepare `cache_position` if the `past_key_values` already contains previous keys and values.
+
+
+
 This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ), [Pedro Cuenca](https://huggingface.co/pcuenq) and [Tom Arsen]().
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 2fda315bb597..8953238186c6 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -807,7 +807,26 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)
 
         if cache_position is None:
-            cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+            if past_key_values is None:
+                cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+            else:
+                raise ValueError("When `past_key_values` is passed, `cache_position` must be too")
+
+        # Probably a forward call with caching, so we set up cache for one call only
+        if use_cache and past_key_values is None and not self.training:
+            logger.warning_once(
+                "You are calling the model with `use_cache=True` but didn't pass `past_key_values` while not training. "
+                "If you want to compute with cache, make sure to pass an instance of `HybridCache`. An empty `HybridCache` instance "
+                "will be created for this call. See for more: (https://huggingface.co/docs/transformers/main/en/internal/generation_utils#transformers.HybridCache)",
+            )
+            batch_size, seq_len, _ = inputs_embeds.shape
+            past_key_values = HybridCache(
+                self.config,
+                max_batch_size=batch_size,
+                max_cache_len=seq_len,
+                device=self.device,
+                dtype=inputs_embeds.dtype,
+            )
 
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
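
For context, here is a minimal usage sketch of the caching workflow the new doc tip describes. The checkpoint name, prompt, and cache length are assumptions and not part of this patch; the `HybridCache` constructor arguments mirror the ones used in the modeling change above.

```python
# Hypothetical usage sketch of HybridCache with Gemma2 (checkpoint name, prompt,
# and cache length are assumptions; constructor arguments follow the patch above).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache

model_id = "google/gemma-2-9b"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("The capital of France is", return_tensors="pt")

# Pre-allocate a cache large enough for the prompt plus the tokens we plan to add.
past_key_values = HybridCache(
    model.config,
    max_batch_size=1,
    max_cache_len=inputs.input_ids.shape[1] + 32,
    device=model.device,
    dtype=model.dtype,
)

# First call: cache_position covers the prompt positions.
cache_position = torch.arange(inputs.input_ids.shape[1], device=model.device)
outputs = model(
    **inputs,
    past_key_values=past_key_values,
    cache_position=cache_position,
    use_cache=True,
)

# Subsequent calls: the cache already holds keys/values, so cache_position must
# point past them (this is the `cache_position` preparation the doc tip mentions).
next_token = outputs.logits[:, -1:].argmax(dim=-1)
cache_position = cache_position[-1:] + 1
outputs = model(
    input_ids=next_token,
    past_key_values=past_key_values,
    cache_position=cache_position,
    use_cache=True,
)
```

The fixed-shape, pre-allocated cache is what accommodates the alternating sliding-window and global-attention layers, which is why the tip steers users away from `DynamicCache` and plain tuples of tensors.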