From d6fd2d20dd963ad22f8173ac0c0f8a671bca67a1 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 24 Jul 2024 09:38:31 +0200 Subject: [PATCH 1/3] let's not warn when someone is running a foward without cache + self.training --- src/transformers/models/llama/modeling_llama.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index adb455acfbbc..d553ce4432f9 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -894,7 +894,9 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( From a31f82ab8067e2a3f4b730421e1e403020e75479 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 24 Jul 2024 09:43:08 +0200 Subject: [PATCH 2/3] more models --- src/transformers/models/dbrx/modeling_dbrx.py | 2 +- src/transformers/models/mistral/modeling_mistral.py | 2 +- src/transformers/models/mixtral/modeling_mixtral.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 2 +- src/transformers/models/persimmon/modeling_persimmon.py | 2 +- src/transformers/models/phi/modeling_phi.py | 2 +- src/transformers/models/phi3/modeling_phi3.py | 2 +- src/transformers/models/qwen2/modeling_qwen2.py | 2 +- src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 2 +- src/transformers/models/stablelm/modeling_stablelm.py | 2 +- src/transformers/models/starcoder2/modeling_starcoder2.py | 2 +- 11 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index e3be8decbc6b..a80016587764 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1005,7 +1005,7 @@ def forward( inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training) return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + if use_cache and not isinstance(past_key_values, Cache) and not self.training: # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 93a60a49dbf3..7c339271a589 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -758,7 +758,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and not isinstance(past_key_values, Cache) and not self.training: past_key_values = DynamicCache.from_legacy_cache(past_key_values) return_legacy_cache = True logger.warning_once( diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index d2ee6e6b268a..7df175a0467d 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -960,7 +960,7 @@ def forward( use_cache = False use_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and not isinstance(past_key_values, Cache) and not self.training: use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 74d49d5606c1..75b674bf633f 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -811,7 +811,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + if use_cache and not isinstance(past_key_values, Cache) and not self.training: # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index af22145e3e9d..c718b7a40633 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -626,7 +626,7 @@ def forward( use_cache = False use_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and not isinstance(past_key_values, Cache) and not self.training: use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 1b23be39e5c0..1289910381e5 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -909,7 +909,7 @@ def forward( use_cache = False use_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and not isinstance(past_key_values, Cache) and not self.training: use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 90b815184b07..dfcb7c2dd009 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -950,7 +950,7 @@ def forward( use_cache = False use_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and not isinstance(past_key_values, Cache) and not self.training: use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 1ff8896ae5f9..67c8c71fe59d 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -808,7 +808,7 @@ def forward( use_cache = False use_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and not isinstance(past_key_values, Cache) and not self.training: use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 54e91da3347d..9b50235c15d5 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -970,7 +970,7 @@ def forward( use_cache = False use_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and not isinstance(past_key_values, Cache) and not self.training: use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 3a3b6a9e05f1..086f525644b2 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -902,7 +902,7 @@ def forward( use_cache = False use_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and not isinstance(past_key_values, Cache) and not self.training: use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index f2786f9df48a..18dbe510450f 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -784,7 +784,7 @@ def forward( use_cache = False use_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): + if use_cache and not isinstance(past_key_values, Cache) and not self.training: use_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( From cc7c1ddb511f21239c7be1e32e05928829d2bfec Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 24 Jul 2024 09:43:26 +0200 Subject: [PATCH 3/3] fixup --- src/transformers/models/cohere/modeling_cohere.py | 4 +++- src/transformers/models/dbrx/modeling_dbrx.py | 4 +++- src/transformers/models/gemma/diff_gemma.py | 4 +++- src/transformers/models/gemma/modeling_gemma.py | 8 ++++++-- src/transformers/models/jetmoe/modeling_jetmoe.py | 4 +++- src/transformers/models/olmo/modeling_olmo.py | 4 +++- 6 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 6532c656d453..6257eeb9958c 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -769,7 +769,9 @@ def forward( past_seen_tokens = 0 return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index a80016587764..b1f3ce1b8ba9 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1005,7 +1005,9 @@ def forward( inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training) return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache) and not self.training: # kept for BC (non `Cache` `past_key_values` inputs) + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( diff --git a/src/transformers/models/gemma/diff_gemma.py b/src/transformers/models/gemma/diff_gemma.py index d2a653120965..4e2ea82950bf 100644 --- a/src/transformers/models/gemma/diff_gemma.py +++ b/src/transformers/models/gemma/diff_gemma.py @@ -474,7 +474,9 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) return_legacy_cache = False # noqa: F841 - if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True # noqa: F841 past_key_values = DynamicCache.from_legacy_cache(past_key_values) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 5bc1af3e7ec7..ae27fe98512f 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -770,7 +770,9 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) return_legacy_cache = False # noqa: F841 - if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True # noqa: F841 past_key_values = DynamicCache.from_legacy_cache(past_key_values) @@ -795,7 +797,9 @@ def forward( # See https://github.com/huggingface/transformers/pull/29402 normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) hidden_states = hidden_states * normalizer - if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once( diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index fa15393a40a5..583751520183 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -978,7 +978,9 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 75b674bf633f..16e8711188dd 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -811,7 +811,9 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache) and not self.training: # kept for BC (non `Cache` `past_key_values` inputs) + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) return_legacy_cache = True past_key_values = DynamicCache.from_legacy_cache(past_key_values) logger.warning_once(