Remove ambiguous padding_mask and instead use a 2D->4D Attn Mask Mapper #26792

Merged
28 commits merged on Oct 23, 2023
Changes from 2 commits

Commits (28)
3b5dfb7  [Attn Mask Converter] refactor attn mask (patrickvonplaten, Oct 13, 2023)
5032628  up (patrickvonplaten, Oct 13, 2023)
0bfbc1a  Apply suggestions from code review (patrickvonplaten, Oct 16, 2023)
d173ce3  improve (patrickvonplaten, Oct 16, 2023)
1d230de  improve (patrickvonplaten, Oct 16, 2023)
e542a6d  rename (patrickvonplaten, Oct 16, 2023)
0256cd5  better cache (patrickvonplaten, Oct 16, 2023)
7e50ec0  renaming (patrickvonplaten, Oct 16, 2023)
30fc4c3  improve more (patrickvonplaten, Oct 16, 2023)
2301d6b  improve (patrickvonplaten, Oct 16, 2023)
4cfb7cb  fix bug (patrickvonplaten, Oct 16, 2023)
bdb39ae  finalize (patrickvonplaten, Oct 16, 2023)
67bd54e  Merge branch 'main' of https://github.com/huggingface/transformers in… (patrickvonplaten, Oct 19, 2023)
2c42f7d  Merge branch 'attn_mask_converter' of https://github.com/huggingface/… (patrickvonplaten, Oct 19, 2023)
4387ab8  make style & make fix-copies (patrickvonplaten, Oct 19, 2023)
068ed57  correct more (patrickvonplaten, Oct 19, 2023)
d18268a  start moving attention_mask (patrickvonplaten, Oct 19, 2023)
4a99e43  fix llama (patrickvonplaten, Oct 19, 2023)
2fe66a0  improve falcon (patrickvonplaten, Oct 19, 2023)
d03d8a1  up (patrickvonplaten, Oct 19, 2023)
9e553ac  improve more (patrickvonplaten, Oct 19, 2023)
431b3a8  improve more (patrickvonplaten, Oct 19, 2023)
f8b2e4e  Update src/transformers/models/owlv2/modeling_owlv2.py (patrickvonplaten, Oct 19, 2023)
0338ffe  make style (patrickvonplaten, Oct 19, 2023)
eb95315  Merge branch 'attn_mask_converter' of https://github.com/huggingface/… (patrickvonplaten, Oct 19, 2023)
ae3eb2e  make style (patrickvonplaten, Oct 19, 2023)
b438dc8  rename to converter (patrickvonplaten, Oct 19, 2023)
5339988  Apply suggestions from code review (patrickvonplaten, Oct 19, 2023)

104 changes: 56 additions & 48 deletions src/transformers/models/llama/modeling_llama.py
@@ -41,6 +41,49 @@
from .configuration_llama import LlamaConfig


class AttentionMask2DTo4D:
Review comment (Collaborator):
Do we plan to move this to the modelling utils or is this gonna be here for all models?

Reply (patrickvonplaten, Contributor, Author):
Either #Copied from or we move it to a utils file. Both would work for me.

def __init__(self, is_causal: bool):
self.is_causal = is_causal
self.cached_2d_tensor = None
self.cached_4d_tensor = None

def __call__(self, attention_mask_2d: torch.Tensor, input_shape, past_key_values_length, dtype):
"""
Multiplies the given tensor x by -10,000.
If the cached tensor does not exist or has a different size, a new one is allocated.
patrickvonplaten marked this conversation as resolved.
Show resolved Hide resolved
"""
if self.cached_2d_tensor is None or (attention_mask_2d != self.cached_2d_tensor).any():
self.cached_2d_tensor = attention_mask_2d
self.cached_4d_tensor = self._prepare_decoder_attention_mask(attention_mask_2d, input_shape, past_key_values_length, dtype)

return self.cached_4d_tensor

# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length, dtype):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
dtype,
device=attention_mask.device,
past_key_values_length=past_key_values_length,
)

if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, dtype, tgt_len=input_shape[-1]).to(
attention_mask.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)

return combined_attention_mask



if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
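As an illustration of the caching behaviour described in the converter's docstring above, here is a minimal, hypothetical sketch of exercising AttentionMask2DTo4D on its own. It assumes the module-level _make_causal_mask and _expand_mask helpers of modeling_llama are in scope (they are not shown in this hunk); the tensor values are arbitrary.

import torch

# hypothetical standalone use of the converter defined above
converter = AttentionMask2DTo4D(is_causal=True)

attention_mask_2d = torch.tensor([[1, 1, 1, 0]])  # (bsz=1, seq_len=4), 0 marks a padded position

# the first call builds the 4D causal + padding mask and caches it
mask_4d = converter(attention_mask_2d, (1, 4), 0, torch.float32)
print(mask_4d.shape)  # torch.Size([1, 1, 4, 4])

# a second call with an identical 2D mask returns the cached 4D tensor without recomputing it
assert converter(attention_mask_2d, (1, 4), 0, torch.float32) is mask_4d
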
@@ -262,7 +305,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(self, config: LlamaConfig):
def __init__(self, config: LlamaConfig, mask_converter=None):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
@@ -272,6 +315,7 @@ def __init__(self, config: LlamaConfig):
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.mask_converter = mask_converter

if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
@@ -376,6 +420,8 @@ def forward(
f" {attn_weights.size()}"
)

# convert the 2D mask to a 4D mask, re-using the cached mask if available
attention_mask = self.mask_converter(attention_mask)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
@@ -420,12 +466,11 @@ class LlamaFlashAttention2(LlamaAttention):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# LlamaFlashAttention2 attention does not support output_attentions
output_attentions = False
@@ -485,7 +530,7 @@ def forward(
value_states = value_states.to(torch.float16)

attn_output = self._flash_attention_forward(
query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
)

attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -538,7 +583,7 @@ def _flash_attention_forward(
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=True,
causal=self.mask_converter.is_causal,
)

attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
@@ -589,13 +634,13 @@ def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_l


class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
def __init__(self, config: LlamaConfig, mask_converter=None):
Review comment (Collaborator):
If we support passing the mask converter here but not in the parent classes it's kind of pointless, no?
Wondering which one would be best:

  1. Pass the mask converter class to all classes.
  2. Only have it in the attention layer, controlled with a MASK_CONVERTER = {"default": AttentionMask2DTo4D}, and just in the attention layer do self.mask_converter = MASK_CONVERTER[config.mask_converter], with the attribute added to the common config?
     (naming can be improved for sure!)

Reply (patrickvonplaten, Contributor, Author):
Sorry, I don't fully understand this.

Reply (patrickvonplaten, Contributor, Author, Oct 16, 2023):
To begin with I would not make the "attention cache" a class that the user plays around with, but instead use it as an internal convenience class that doesn't sacrifice speed but helps readability.

Since the same instance of the class needs to be shared among the different layers, we need to instantiate it at a ...Model level and then let it trickle down to the respective attention classes.

Reply (Collaborator):
I see, sorry, I realised that if you want to share the same cached mask you have to pass it, ignore my comment.

super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = (
LlamaAttention(config=config)
LlamaAttention(config=config, mask_converter=mask_converter)
if not getattr(config, "_flash_attn_2_enabled", False)
else LlamaFlashAttention2(config=config)
else LlamaFlashAttention2(config=config, mask_converter=mask_converter)
)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
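To make the shared-cache point from the review thread above concrete: a single converter instance has to be passed down so that every attention layer reads from the same cached 4D mask. A minimal sketch of that wiring follows; it mirrors the LlamaModel changes further down in this diff, and assumes config is a LlamaConfig and the classes above are in scope.

import torch.nn as nn

# one converter per model: it owns the cached 2D and 4D masks
attn_mask_converter = AttentionMask2DTo4D(is_causal=True)

# every decoder layer, and therefore every attention module, receives the *same* instance,
# so the 2D->4D expansion runs once and is then served from the cache for as long as the
# incoming 2D mask stays unchanged
layers = nn.ModuleList(
    [LlamaDecoderLayer(config, mask_converter=attn_mask_converter) for _ in range(config.num_hidden_layers)]
)
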
@@ -609,7 +654,6 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
Review comment (Contributor):
@patrickvonplaten Make sure to change the docstring line 728 (of this branch):

attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
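For reference, the shape convention quoted in that docstring can be reproduced with plain torch. This generic sketch does not use the transformers helpers; it only shows how a 2D padding mask maps to a (batch, 1, tgt_len, src_len) additive mask.

import torch

dtype = torch.float32
bsz, tgt_len, src_len = 2, 4, 4

# 2D padding mask: 1 = attend, 0 = padded
mask_2d = torch.tensor([[1, 1, 1, 0],
                        [1, 1, 0, 0]])

# expand to (batch, 1, tgt_len, src_len) and turn padded positions into very large negative values
expanded = mask_2d[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
additive_mask = (1.0 - expanded) * torch.finfo(dtype).min
print(additive_mask.shape)  # torch.Size([2, 1, 4, 4])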

@@ -637,7 +681,6 @@ def forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = residual + hidden_states

@@ -784,8 +827,9 @@ def __init__(self, config: LlamaConfig):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

attn_mask_converter = AttentionMask2DTo4D(is_causal=True)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.layers = nn.ModuleList([LlamaDecoderLayer(config, attn_mask_converter) for _ in range(config.num_hidden_layers)])
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

self.gradient_checkpointing = False
@@ -798,30 +842,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.embed_tokens = value

# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)

if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)

return combined_attention_mask

@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
def forward(
self,
@@ -870,17 +890,6 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
padding_mask = None
else:
if 0 in attention_mask:
padding_mask = attention_mask
else:
padding_mask = None

attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
@@ -910,7 +919,7 @@ def forward(
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)
return module(*inputs, past_key_value, output_attentions)

return custom_forward

@@ -925,7 +934,6 @@ def custom_forward(*inputs):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)

hidden_states = layer_outputs[0]
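
Stepping back from the diff: the user-facing contract this PR keeps is that callers only ever pass the plain 2D attention_mask, while the 2D->4D expansion (and the flash-attention padding handling) stays internal. A minimal end-to-end sketch, with small, arbitrary config values chosen purely for illustration:

import torch
from transformers import LlamaConfig, LlamaModel

# tiny configuration so the example runs quickly; the values are not meaningful
config = LlamaConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
)
model = LlamaModel(config)

input_ids = torch.tensor([[11, 12, 13, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0]])  # plain 2D mask, 0 marks padding

outputs = model(input_ids=input_ids, attention_mask=attention_mask)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 4, 64])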