diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py
index 734f443e1fc9..f0964f940285 100755
--- a/src/transformers/modeling_attn_mask_utils.py
+++ b/src/transformers/modeling_attn_mask_utils.py
@@ -302,10 +302,22 @@ def _prepare_4d_causal_attention_mask(
     key_value_length = input_shape[-1] + past_key_values_length
 
     # 4d mask is passed through the layers
-    if attention_mask is not None:
+    if attention_mask is not None and len(attention_mask.shape) == 2:
         attention_mask = attn_mask_converter.to_4d(
             attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
         )
+    elif attention_mask is not None and len(attention_mask.shape) == 4:
+        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+        if tuple(attention_mask.shape) != expected_shape:
+            raise ValueError(
+                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+            )
+        else:
+            # if the 4D mask has the correct shape - invert it and fill with negative infinity
+            inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
+            attention_mask = inverted_mask.masked_fill(
+                inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
+            )
     else:
         attention_mask = attn_mask_converter.to_causal_4d(
             input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
@@ -340,7 +352,22 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
     is_tracing = torch.jit.is_tracing()
 
     if attention_mask is not None:
-        if torch.all(attention_mask == 1):
+        # 4d mask is passed through
+        if len(attention_mask.shape) == 4:
+            expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+            if tuple(attention_mask.shape) != expected_shape:
+                raise ValueError(
+                    f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+                )
+            else:
+                # if the 4D mask has the correct shape - invert it and fill with negative infinity
+                inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
+                attention_mask = inverted_mask.masked_fill(
+                    inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
+                )
+            return attention_mask
+
+        elif torch.all(attention_mask == 1):
             if is_tracing:
                 pass
             elif query_length == 1:
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index ddfaad5214dc..86c07e5b7273 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
+import gc
 import glob
 import json
 import os
@@ -49,6 +50,7 @@
     require_tf,
     require_torch,
     require_torch_accelerator,
+    require_torch_gpu,
     require_torch_multi_accelerator,
     require_usr_bin_time,
     slow,
@@ -1850,3 +1852,134 @@ def test_not_available_sdpa(self):
         )
 
         self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception))
+
+
+@slow
+@require_torch_gpu
+class Mask4DTestBase(unittest.TestCase):
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_test_data(self):
+        texts = ["the cat sat", "the cat had", "the cat is"]
+        encoded = [self.tokenizer.encode(t) for t in texts]
+        input_0 = torch.tensor(encoded, device=torch_device)
+        # tensor([[   1,  278, 6635, 3290],
+        #         [   1,  278, 6635,  750],
+        #         [   1,  278, 6635,  338]], device='cuda:0')
+
+        # Combining the common prefix with the unique ending tokens:
+        input_1 = torch.cat([input_0[0][:-1], input_0[:, -1]]).unsqueeze(0)
+        # tensor([[   1,  278, 6635, 3290,  750,  338]], device='cuda:0')
+
+        # Creating a 4D mask where the last 3 tokens do not attend to each other.
+        mask_1 = torch.tensor(
+            [
+                [
+                    [
+                        [1, 0, 0, 0, 0, 0],
+                        [1, 1, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0],
+                        [1, 1, 1, 1, 0, 0],
+                        [1, 1, 1, 0, 1, 0],
+                        [1, 1, 1, 0, 0, 1],
+                    ]
+                ]
+            ],
+            device="cuda:0",
+            dtype=torch.int64,
+        )
+
+        # Creating a position_ids tensor. Note the repeating values at the end.
+        position_ids_1 = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)
+
+        return input_0, input_1, mask_1, position_ids_1
+
+
+@slow
+@require_torch_gpu
+class Mask4DTestFP32(Mask4DTestBase):
+    def setUp(self):
+        model_name = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
+        model_dtype = torch.float32
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(torch_device)
+
+    def test_attention(self):
+        """comparing outputs of the attention layer"""
+        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
+
+        hid_0 = self.model.model.embed_tokens(input_0)
+        outs_0 = self.model.model.layers[0].self_attn.forward(hid_0)[0]
+        # outs_0.shape == torch.Size([3, 4, 768])
+
+        hid_1 = self.model.model.embed_tokens(input_1)
+        outs_1 = self.model.model.layers[0].self_attn.forward(
+            hid_1, attention_mask=mask_1.bool(), position_ids=position_ids_1
+        )[0]
+        # outs_1.shape == torch.Size([1, 6, 768])
+
+        outs_0_last_tokens = outs_0[:, -1, :]  # last tokens in each batch line
+        outs_1_last_tokens = outs_1[0, -3:, :]  # last three tokens
+        assert torch.allclose(outs_0_last_tokens, outs_1_last_tokens)
+
+    def test_inner_model(self):
+        """comparing hidden outputs of the whole inner model"""
+        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
+
+        hidden_0 = self.model.model.forward(input_0).last_hidden_state
+        hidden_1 = self.model.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).last_hidden_state
+
+        hidden_0_last_tokens = hidden_0[:, -1, :]  # last tokens in each batch line
+        hidden_1_last_tokens = hidden_1[0, -3:, :]  # last three tokens
+        torch.testing.assert_close(
+            hidden_0_last_tokens,
+            hidden_1_last_tokens,
+        )
+
+    def test_causal_model_logits(self):
+        """comparing logits outputs of the whole causal model"""
+        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
+
+        logits_0 = self.model.forward(input_0).logits
+        logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
+
+        logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
+        logits_1_last_tokens = logits_1[0, -3:, :]  # last three tokens
+        torch.testing.assert_close(
+            logits_0_last_tokens,
+            logits_1_last_tokens,
+        )
+
+
+@slow
+@require_torch_gpu
+class Mask4DTestFP16(Mask4DTestBase):
+    test_attention = Mask4DTestFP32.test_attention
+
+    def setUp(self):
+        model_name = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
+        model_dtype = torch.float16
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(torch_device)
+
+    def test_causal_model_logits(self):
+        """comparing logits outputs of the whole causal model"""
+        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
+
+        logits_0 = self.model.forward(input_0).logits
+        logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
+
+        logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
+        logits_1_last_tokens = logits_1[0, -3:, :]  # last three tokens
+
+        indices_0 = logits_0_last_tokens.sort(descending=True).indices
+        indices_1 = logits_1_last_tokens.sort(descending=True).indices
+
+        # checking logits, but note the relaxed tolerances for FP16
+        torch.testing.assert_close(logits_0_last_tokens, logits_1_last_tokens, atol=0.02, rtol=0.001)
+
+        # checking token order for the top tokens
+        for token_ids_0, token_ids_1 in zip(indices_0, indices_1):
+            self.assertTrue(torch.equal(token_ids_0[:128], token_ids_1[:128]))
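
For reference, a minimal standalone sketch (not part of the patch) of what the new 4D-mask branches do: a caller-supplied mask of shape (batch_size, 1, query_length, key_value_length), with 1 = attend and 0 = masked, is validated against the expected shape and then converted into an additive float mask whose blocked positions hold the dtype's most negative value. The helper name invert_4d_mask below is illustrative only and does not exist in the library.

import torch


def invert_4d_mask(attention_mask: torch.Tensor, expected_shape: tuple, dtype: torch.dtype) -> torch.Tensor:
    # Illustrative re-implementation of the added branch: validate the shape,
    # then turn a 1/0 "keep/mask" tensor into an additive attention bias.
    if tuple(attention_mask.shape) != tuple(expected_shape):
        raise ValueError(
            f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {tuple(expected_shape)}."
        )
    inverted_mask = 1.0 - attention_mask.to(dtype)
    # positions with 1 become 0.0 (attend freely); positions with 0 become the
    # most negative representable value, i.e. effectively -inf before the softmax
    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


# Example: a causal mask for a batch of one with query_length == key_value_length == 3
mask = torch.tril(torch.ones(3, 3, dtype=torch.int64))[None, None]  # shape (1, 1, 3, 3)
bias = invert_4d_mask(mask, expected_shape=(1, 1, 3, 3), dtype=torch.float32)
# bias is 0.0 where attention is allowed and torch.finfo(torch.float32).min elsewhere

This mirrors how the new tests exercise the feature: the model is called with attention_mask=mask_1.bool() of shape (1, 1, 6, 6) plus matching position_ids, so three packed continuations share a single common prefix within one forward pass.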