diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index 0c61848924a4..1b0ad5ad92fe 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -639,7 +639,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/examples/modular-transformers/modeling_multimodal1.py b/examples/modular-transformers/modeling_multimodal1.py index 45b10a5b206a..ec54af22186e 100644 --- a/examples/modular-transformers/modeling_multimodal1.py +++ b/examples/modular-transformers/modeling_multimodal1.py @@ -639,7 +639,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index ae71d724c25a..86669310c4f8 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -644,7 +644,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index e44c4bde1987..454860458636 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -561,7 +561,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 0b330b4aeeda..faa117fb374a 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -1054,7 +1054,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index 30017181738e..7e0b7217d734 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -1352,7 +1352,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index 7fb35f48fb3b..c890f311e754 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -1099,7 +1099,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 9d7325c502d6..8319ac05bc2b 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -785,7 +785,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index edbac91bb060..93145500a9f1 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1430,7 +1430,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 44cc2a3357c6..a0c70f58cf97 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -628,7 +628,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 0d0d6ccf9e85..da21f31209de 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -703,7 +703,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index a2373d345412..a24afcefb1a0 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1156,7 +1156,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index fc1e38549b7e..ef5a1f05d2c7 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -943,7 +943,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 722d9078d28a..f959d6fb0bae 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -1527,7 +1527,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index c0fad1ab66d5..a4a75f751232 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1087,7 +1087,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 66e975edaa53..6b5bf7a1de48 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -675,7 +675,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 3e5107c561df..dbe077fd3dee 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -685,7 +685,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index b4237370f1c3..3c4ddf7d0c20 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -927,7 +927,7 @@ def forward( # [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose. self_attention_mask = self_attention_mask.transpose(1, 2) - if query_length > 1 and attention_mask is not None and attention_mask.device.type == "cuda": + if query_length > 1 and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu"]: # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 self_attention_mask = AttentionMaskConverter._unmask_unattended( diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 3d30c9260c60..c30f178d7576 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -837,7 +837,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index e4814ce4e7c7..0fd43e9f31a2 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -945,7 +945,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 6a9ae6b50f90..09c207c225c7 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -708,7 +708,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 804218d588f9..1deda1631405 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -936,7 +936,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 3c887d3a1b91..49d5abe4ece0 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -688,7 +688,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 77ab0cece3ea..77aacce9ae2f 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -1177,7 +1177,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 7eed89b4af33..1d51137f0caa 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -672,7 +672,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 2e502d02fdef..2aa8d7bf51b3 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -1408,7 +1408,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index fd6b1bae31b1..5d34fca936ea 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -1379,7 +1379,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 433ca61fabec..956ba4bf1b16 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -1167,7 +1167,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8cbb12628c0a..7131aa128125 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -674,7 +674,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index c85c282439fe..64984b9c8f4e 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -1645,7 +1645,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 8342f83cfc54..5a3ee9594cc1 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -1133,7 +1133,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 635cda9cc8f0..e75d19e42fe1 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -665,7 +665,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index 362233a21b70..fb4d788ff857 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -185,7 +185,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 8cf2d0e8fa8d..af89067ab02c 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -799,7 +799,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index b40c366a6d75..62073d81854a 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -1121,7 +1121,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index 50894e3cff44..19779f297adf 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -974,7 +974,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index a1c15b7a0b37..9534399837ae 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -1362,7 +1362,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when @@ -1674,7 +1674,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 5667d2635b7c..7be31c35c3f5 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1237,7 +1237,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 54f774f0b942..5d5a0526ba1b 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -923,7 +923,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 34e9f7259cdf..56190e7648de 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -650,7 +650,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index a6a19265015b..683b9758b8fc 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -651,7 +651,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 5c78138c1a03..56ed9d054789 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -1085,7 +1085,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 8336ab5a2cf5..e89b502a791e 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -724,7 +724,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index c81fbbd013a8..b8ac285a6298 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -648,7 +648,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index cf905cb62e90..b6e8c986e675 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -735,7 +735,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index b540dd18300e..4b3760edf2f8 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -1248,7 +1248,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 77ce68659a50..6f22682481ff 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -1632,7 +1632,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 29cf09e34cfd..43aba63c8a41 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -1045,7 +1045,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index f8be4e3740f4..c3e6edcac255 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -659,7 +659,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 0f61323f4030..e3c412968c58 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1134,7 +1134,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 4893143e2c16..d4250fd1aac4 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1234,7 +1234,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index 7fc01e95e371..37a7666cf405 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -765,7 +765,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) - if attention_mask is not None and attention_mask.device.type == "cuda": + if attention_mask is not None and attention_mask.device.type in ["cuda", "xpu"]: # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. # Details: https://github.com/pytorch/pytorch/issues/110213 diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 4cdab6dc4d2d..4f23770f50d2 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -979,7 +979,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 8620c7d69d56..c33d443f0691 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -660,7 +660,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 80d765ce2e7f..cd6dfdb52f73 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1181,7 +1181,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 99b397ff12d6..a1e33404f4b6 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1250,7 +1250,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 88fe82b640f8..af54dcdd87dc 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -1583,7 +1583,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 2731150fd0e8..00bee69bf907 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -894,7 +894,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 150d683e927c..fba4ae114149 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1420,7 +1420,7 @@ def _update_causal_mask( if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] and not output_attentions ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 761c799bdcdc..a490dc4461af 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -1177,7 +1177,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): if ( self.config._attn_implementation == "sdpa" and attention_mask is not None - and attention_mask.device.type == "cuda" + and attention_mask.device.type in ["cuda", "xpu"] ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.