From ecc7e438570d85586e65db98ddfc32b231dcfa85 Mon Sep 17 00:00:00 2001 From: raushan Date: Fri, 26 Jul 2024 09:28:08 +0200 Subject: [PATCH 1/8] gemma2 fallback to dynamic cache --- .../models/gemma2/modeling_gemma2.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 10d00fa460ba..942d00098460 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -27,7 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -744,8 +744,24 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: + if past_key_values is not None and not past_key_values.empty: + logger.warning( + "You are calling the model with non-empty `past_key_values` but didn't pass `cache_position`. ", + "This will results in incorrect logits. Please pass `cache_position` that indicates indicates " + "input's positions in the sequence.", + ) cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + # Probably a forward call with caching, so we set up cache for one call only + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + logger.warning( + "You are calling the model with `use_cache=True` but didn't pass `past_key_values`. ", + "Be default the model will use a dynamic cache, which doesn't support sliding window. " + "In case you are calling iteratively to generate, please initiate `past_key_values` " + "outside as `HybridCache` class.", + ) + if position_ids is None: position_ids = cache_position.unsqueeze(0) @@ -829,7 +845,7 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if past_key_values is not None: + if past_key_values is not None and past_key_values.get_max_length() is not None: target_length = past_key_values.get_max_length() else: target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1] From 7c27d111abff82e2b000e3e01541d91d6bb657b2 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Tue, 30 Jul 2024 10:09:35 +0500 Subject: [PATCH 2/8] Update src/transformers/models/gemma2/modeling_gemma2.py Co-authored-by: Joao Gante --- src/transformers/models/gemma2/modeling_gemma2.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index cdc02a33b0e2..68130786ab6c 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -747,13 +747,10 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - if past_key_values is not None and not past_key_values.empty: - logger.warning( - "You are calling the model with non-empty `past_key_values` but didn't pass `cache_position`. ", - "This will results in incorrect logits. 
Please pass `cache_position` that indicates indicates " - "input's positions in the sequence.", - ) - cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + if past_key_values is None: + cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + else: + raise ValueError("When `past_key_values` is passed, `cache_position` must be too") # Probably a forward call with caching, so we set up cache for one call only if use_cache and past_key_values is None: From ff62a4f99eb694e2a895d3c8023e5c9737e4c9cc Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Wed, 31 Jul 2024 18:28:18 +0500 Subject: [PATCH 3/8] Update src/transformers/models/gemma2/modeling_gemma2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/gemma2/modeling_gemma2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 68130786ab6c..62e3a5fcc690 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -753,7 +753,7 @@ def forward( raise ValueError("When `past_key_values` is passed, `cache_position` must be too") # Probably a forward call with caching, so we set up cache for one call only - if use_cache and past_key_values is None: + if use_cache and past_key_values is None and not self.training: past_key_values = DynamicCache() logger.warning( "You are calling the model with `use_cache=True` but didn't pass `past_key_values`. ", From f34c8192f02e371f1a0bdf51351da2db2da0a72b Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 1 Aug 2024 11:48:29 +0200 Subject: [PATCH 4/8] raise error and dont fallback to dynamic cache --- docs/source/en/internal/generation_utils.md | 9 +++++++++ src/transformers/models/gemma2/modeling_gemma2.py | 10 ++++------ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index da7ea25e54b6..fbd3dee6b523 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -391,6 +391,15 @@ A [`Constraint`] can be used to force the generation to include specific tokens - get_seq_length - reset +[[autodoc]] HybridCache + - update + - get_seq_length + - reset + +[[autodoc]] SlidingWindowCache + - update + - reset + [[autodoc]] EncoderDecoderCache - get_seq_length - to_legacy_cache diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 62e3a5fcc690..8bd685565ff4 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -27,7 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache +from ...cache_utils import Cache from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -754,12 +754,10 @@ def forward( # Probably a forward call with caching, so we set up cache for one call only if use_cache and past_key_values is None and not self.training: - past_key_values = DynamicCache() - logger.warning( + raise ValueError( "You are calling the model with `use_cache=True` but didn't pass `past_key_values`. ", - "Be default the model will use a dynamic cache, which doesn't support sliding window. 
" - "In case you are calling iteratively to generate, please initiate `past_key_values` " - "outside as `HybridCache` class.", + "Make sure to pass an instance of `HybridCache`. See for more: " + "(https://huggingface.co/docs/transformers/main/en/internal/generation_utils#transformers.HybridCache)", ) if position_ids is None: From 26dbe828bfc517c1d3518d4b4e1d1344a56fa728 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 1 Aug 2024 12:11:35 +0200 Subject: [PATCH 5/8] prev will break most forward calls/tests --- src/transformers/models/gemma2/modeling_gemma2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 8bd685565ff4..12921a43f0ed 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -754,11 +754,12 @@ def forward( # Probably a forward call with caching, so we set up cache for one call only if use_cache and past_key_values is None and not self.training: - raise ValueError( + logger.warning_once( "You are calling the model with `use_cache=True` but didn't pass `past_key_values`. ", - "Make sure to pass an instance of `HybridCache`. See for more: " + "Make sure to pass an instance of `HybridCache`. Caching will be disabled. See for more: " "(https://huggingface.co/docs/transformers/main/en/internal/generation_utils#transformers.HybridCache)", ) + use_cache = False if position_ids is None: position_ids = cache_position.unsqueeze(0) From a69c06b0defec68b9ee19ddc6edd24d172177431 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Mon, 5 Aug 2024 11:38:39 +0500 Subject: [PATCH 6/8] Update src/transformers/models/gemma2/modeling_gemma2.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/gemma2/modeling_gemma2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index bade3dcb5755..226325cb1d5f 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -815,8 +815,8 @@ def forward( # Probably a forward call with caching, so we set up cache for one call only if use_cache and past_key_values is None and not self.training: logger.warning_once( - "You are calling the model with `use_cache=True` but didn't pass `past_key_values`. ", - "Make sure to pass an instance of `HybridCache`. Caching will be disabled. See for more: " + "You are calling the model with `use_cache=True` but didn't pass `past_key_values` while not training. ", + "If you want to compute with cache, make sure to pass an instance of `HybridCache`. Caching will be disabled otherwise. 
See for more: " "(https://huggingface.co/docs/transformers/main/en/internal/generation_utils#transformers.HybridCache)", ) use_cache = False From 240500d2f198e9b768ece914efbf77d6e3b5af7f Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 5 Aug 2024 08:56:04 +0200 Subject: [PATCH 7/8] update --- docs/source/en/model_doc/gemma2.md | 6 ++++++ src/transformers/__init__.py | 4 ++++ src/transformers/models/gemma2/modeling_gemma2.py | 13 ++++++++++--- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/docs/source/en/model_doc/gemma2.md b/docs/source/en/model_doc/gemma2.md index 5befa0b1f437..431c4ecd25f2 100644 --- a/docs/source/en/model_doc/gemma2.md +++ b/docs/source/en/model_doc/gemma2.md @@ -30,6 +30,12 @@ Tips: - The original checkpoints can be converted using the conversion script `src/transformers/models/Gemma2/convert_Gemma2_weights_to_hf.py` + + +- Gemma2 uses sliding window attention every second layer, which makes it unsuitable for typical kv caching with [`~DynamicCache`] or tuples of tensors. To enable caching in Gemma2 forward call, you must initialize a [`~HybridCache`] instance and pass it as `past_key_values` to the forward call. Note, that you also have to prepare `cache_position` if the `past_key_values` already contains previous keys and values. + + + This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ), [Pedro Cuenca](https://huggingface.co/pcuenq) and [Tom Arsen](). diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 4c953bab6be4..92364ffa0529 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1226,10 +1226,12 @@ "DynamicCache", "EncoderDecoderCache", "HQQQuantizedCache", + "HybridCache", "QuantizedCache", "QuantizedCacheConfig", "QuantoQuantizedCache", "SinkCache", + "SlidingWindowCache", "StaticCache", ] _import_structure["data.datasets"] = [ @@ -5948,10 +5950,12 @@ DynamicCache, EncoderDecoderCache, HQQQuantizedCache, + HybridCache, QuantizedCache, QuantizedCacheConfig, QuantoQuantizedCache, SinkCache, + SlidingWindowCache, StaticCache, ) from .data.datasets import ( diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 226325cb1d5f..8953238186c6 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -816,10 +816,17 @@ def forward( if use_cache and past_key_values is None and not self.training: logger.warning_once( "You are calling the model with `use_cache=True` but didn't pass `past_key_values` while not training. ", - "If you want to compute with cache, make sure to pass an instance of `HybridCache`. Caching will be disabled otherwise. See for more: " - "(https://huggingface.co/docs/transformers/main/en/internal/generation_utils#transformers.HybridCache)", + "If you want to compute with cache, make sure to pass an instance of `HybridCache`. An empty `HybridCache` instance " + "will be created for this call. 
See for more: (https://huggingface.co/docs/transformers/main/en/internal/generation_utils#transformers.HybridCache)", + ) + batch_size, seq_len, _ = inputs_embeds.shape + past_key_values = HybridCache( + self.config, + max_batch_size=batch_size, + max_cache_len=seq_len, + device=self.device, + dtype=inputs_embeds.dtype, ) - use_cache = False if position_ids is None: position_ids = cache_position.unsqueeze(0) From 724c62bc7e44e1237c6f7d9e50da9d534cd5579f Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 5 Aug 2024 08:57:27 +0200 Subject: [PATCH 8/8] fix copies --- src/transformers/utils/dummy_pt_objects.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index de739c6e7004..21eb32a21e46 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -51,6 +51,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class HybridCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class QuantizedCache(metaclass=DummyObject): _backends = ["torch"] @@ -79,6 +86,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class SlidingWindowCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class StaticCache(metaclass=DummyObject): _backends = ["torch"]
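Usage sketch (not part of the patches above): a minimal example of the explicit `HybridCache` workflow that the docs tip added in PATCH 7/8 describes, using the same constructor arguments the modeling code passes when it builds an empty cache for a single forward call. The checkpoint name, cache length, and the single greedy decode step are illustrative assumptions, not taken from the patches.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache

model_id = "google/gemma-2-2b"  # assumption: any Gemma2 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The capital of France is", return_tensors="pt")
seq_len = inputs["input_ids"].shape[1]

# Pre-allocate the cache with the same arguments the model uses internally in
# PATCH 7/8; `max_cache_len` bounds prompt length plus generated tokens.
past_key_values = HybridCache(
    config=model.config,
    max_batch_size=1,
    max_cache_len=64,  # assumption: adjust to your generation budget
    device=model.device,
    dtype=model.dtype,
)

# Prefill: `cache_position` holds the absolute position of every input token,
# mirroring `torch.arange(0, inputs_embeds.shape[1])` in the modeling code.
cache_position = torch.arange(seq_len, device=model.device)
with torch.no_grad():
    outputs = model(
        **inputs,
        past_key_values=past_key_values,
        cache_position=cache_position,
        use_cache=True,
    )

# One greedy decode step: advance `cache_position` past the tokens already cached.
next_token = outputs.logits[:, -1:].argmax(dim=-1)
cache_position = cache_position[-1:] + 1
with torch.no_grad():
    outputs = model(
        input_ids=next_token,
        past_key_values=past_key_values,
        cache_position=cache_position,
        use_cache=True,
    )

# Reusing a non-empty cache without `cache_position` raises (PATCH 2/8) instead
# of silently producing wrong logits:
#     model(input_ids=next_token, past_key_values=past_key_values)
#     ValueError: When `past_key_values` is passed, `cache_position` must be too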