Remove ambiguous padding_mask and instead use a 2D->4D Attn Mask Mapper #26792
Changes from 2 commits
@@ -41,6 +41,49 @@
 from .configuration_llama import LlamaConfig
 
+class AttentionMask2DTo4D:
+    def __init__(self, is_causal: bool):
+        self.is_causal = is_causal
+        self.cached_2d_tensor = None
+        self.cached_4d_tensor = None
+
+    def __call__(self, attention_mask_2d: torch.Tensor, input_shape, past_key_values_length, dtype):
+        """
+        Converts a 2D attention mask into a 4D attention mask and caches the result.
+        If no mask has been cached yet, or the given 2D mask differs from the cached one, the 4D mask is recomputed;
+        otherwise the cached 4D mask is returned.
+        """
+        if self.cached_2d_tensor is None or (attention_mask_2d != self.cached_2d_tensor).any():
+            self.cached_2d_tensor = attention_mask_2d
+            self.cached_4d_tensor = self._prepare_decoder_attention_mask(attention_mask_2d, input_shape, past_key_values_length, dtype)
+
+        return self.cached_4d_tensor
+
+    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
+    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, past_key_values_length, dtype):
+        # create causal mask
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        combined_attention_mask = None
+        if input_shape[-1] > 1:
+            combined_attention_mask = _make_causal_mask(
+                input_shape,
+                dtype,
+                device=attention_mask.device,
+                past_key_values_length=past_key_values_length,
+            )
+
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            expanded_attn_mask = _expand_mask(attention_mask, dtype, tgt_len=input_shape[-1]).to(
+                attention_mask.device
+            )
+            combined_attention_mask = (
+                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+            )
+
+        return combined_attention_mask
+
+
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
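To make the caching behavior of the new class concrete, here is a minimal, self-contained sketch of the same idea: rebuild the 4D mask only when the incoming 2D mask changes, otherwise return the cached tensor. The class name CachedMaskConverter, the helper _to_4d, and the inline mask construction are simplified assumptions for illustration, not the code in this PR (which reuses _make_causal_mask and _expand_mask).

# Hedged sketch of the caching idea behind AttentionMask2DTo4D.
# Names and mask construction are simplified assumptions, not the PR's code.
import torch

class CachedMaskConverter:
    def __init__(self, is_causal: bool = True):
        self.is_causal = is_causal
        self._cached_2d = None
        self._cached_4d = None

    def __call__(self, mask_2d: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
        # Recompute only when the incoming 2D mask differs from the cached one.
        if self._cached_2d is None or not torch.equal(mask_2d, self._cached_2d):
            self._cached_2d = mask_2d
            self._cached_4d = self._to_4d(mask_2d, dtype)
        return self._cached_4d

    def _to_4d(self, mask_2d: torch.Tensor, dtype) -> torch.Tensor:
        bsz, seq_len = mask_2d.shape
        min_value = torch.finfo(dtype).min
        # Expand the padding mask to additive form [bsz, 1, seq_len, seq_len]:
        # positions to ignore get a large negative value.
        expanded = mask_2d[:, None, None, :].to(dtype)
        expanded = (1.0 - expanded) * min_value
        if self.is_causal:
            causal = torch.triu(
                torch.full((seq_len, seq_len), min_value, dtype=dtype), diagonal=1
            )
            expanded = expanded + causal[None, None, :, :]
        return expanded

converter = CachedMaskConverter(is_causal=True)
mask = torch.tensor([[1, 1, 1, 0]])        # 1 = attend, 0 = padding
mask_4d = converter(mask)                  # computed once
assert converter(mask) is mask_4d          # same tensor reused from the cache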
@@ -262,7 +305,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 class LlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __init__(self, config: LlamaConfig):
+    def __init__(self, config: LlamaConfig, mask_converter=None):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -272,6 +315,7 @@ def __init__(self, config: LlamaConfig):
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
+        self.mask_converter = mask_converter
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -376,6 +420,8 @@ def forward(
                 f" {attn_weights.size()}"
             )
 
+        # convert 2d -> 4d. Re-use cached mask if available
+        attention_mask = self.attn_mask_converter(attention_mask)
         if attention_mask is not None:
             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
@@ -420,12 +466,11 @@ class LlamaFlashAttention2(LlamaAttention):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
-        padding_mask: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         # LlamaFlashAttention2 attention does not support output_attentions
         output_attentions = False
@@ -485,7 +530,7 @@ def forward(
             value_states = value_states.to(torch.float16)
 
         attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate
+            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
         )
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -538,7 +583,7 @@ def _flash_attention_forward(
                 max_seqlen_k=max_seqlen_in_batch_k,
                 dropout_p=dropout,
                 softmax_scale=softmax_scale,
-                causal=True,
+                causal=self.mask_converter.is_causal,
             )
 
             attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
@@ -589,13 +634,13 @@ def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
 
 
 class LlamaDecoderLayer(nn.Module):
-    def __init__(self, config: LlamaConfig):
+    def __init__(self, config: LlamaConfig, mask_converter=None):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = (
-            LlamaAttention(config=config)
+            LlamaAttention(config=config, mask_converter=mask_converter)
             if not getattr(config, "_flash_attn_2_enabled", False)
-            else LlamaFlashAttention2(config=config)
+            else LlamaFlashAttention2(config=config, mask_converter=mask_converter)
         )
         self.mlp = LlamaMLP(config)
         self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

Review comment: If we support passing the mask converter here but not in the parent classes, it's kind of pointless, no?
Reply: Sorry, I don't fully understand this.
Reply: To begin with I would not make the "attention cache" a class that the user plays around with, but instead use it as an internal convenience class that doesn't sacrifice speed but helps readability. Since the same instance of the class needs to be shared among the different layers, we need to instantiate it at a […]
Reply: I see, sorry, realised that if you want to share the same cached mask you gotta pass it, ignore my comment.
@@ -609,7 +654,6 @@ def forward(
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
-        padding_mask: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:

Review comment: @patrickvonplaten Make sure to change the docstring line 728 (of this branch):
@@ -637,7 +681,6 @@ def forward(
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
-           padding_mask=padding_mask,
        )
        hidden_states = residual + hidden_states
 
@@ -784,8 +827,9 @@ def __init__(self, config: LlamaConfig):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
+        attn_mask_converter = AttentionMask2DTo4D(is_causal=True)
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layers = nn.ModuleList([LlamaDecoderLayer(config, attn_mask_converter) for _ in range(config.num_hidden_layers)])
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.gradient_checkpointing = False
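The hunk above is where the sharing happens: a single AttentionMask2DTo4D instance is created in LlamaModel.__init__ and handed to every decoder layer, which is what the review discussion on LlamaDecoderLayer was getting at. Below is a hedged sketch of that wiring with simplified stand-in modules; SharedMaskLayer and TinyModel are hypothetical names, and the sketch reuses the CachedMaskConverter from the earlier example rather than the real Llama classes.

# Hedged sketch: one shared converter instance across all layers.
# Requires the CachedMaskConverter defined in the earlier sketch.
import torch
import torch.nn as nn

class SharedMaskLayer(nn.Module):
    # Stand-in for a decoder layer that receives the shared converter.
    def __init__(self, hidden_size: int, mask_converter):
        super().__init__()
        self.mask_converter = mask_converter
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attention_mask_2d):
        # Every layer asks the shared converter; only the first call for a new
        # 2D mask builds the 4D mask, subsequent calls hit the cache.
        mask_4d = self.mask_converter(attention_mask_2d)
        _ = mask_4d  # a real layer would add this to the attention scores
        return self.proj(hidden_states)

class TinyModel(nn.Module):
    def __init__(self, hidden_size: int = 16, num_layers: int = 4):
        super().__init__()
        # One converter instance, passed to every layer (mirrors the diff above).
        converter = CachedMaskConverter(is_causal=True)
        self.layers = nn.ModuleList(
            [SharedMaskLayer(hidden_size, converter) for _ in range(num_layers)]
        )

    def forward(self, hidden_states, attention_mask_2d):
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask_2d)
        return hidden_states

model = TinyModel()
x = torch.randn(1, 4, 16)
mask_2d = torch.tensor([[1, 1, 1, 0]])
out = model(x, mask_2d)  # the 4D mask is built once, then reused by all layers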
@@ -798,30 +842,6 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.embed_tokens = value
 
-    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
-    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
-        # create causal mask
-        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        combined_attention_mask = None
-        if input_shape[-1] > 1:
-            combined_attention_mask = _make_causal_mask(
-                input_shape,
-                inputs_embeds.dtype,
-                device=inputs_embeds.device,
-                past_key_values_length=past_key_values_length,
-            )
-
-        if attention_mask is not None:
-            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
-                inputs_embeds.device
-            )
-            combined_attention_mask = (
-                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
-            )
-
-        return combined_attention_mask
-
     @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
     def forward(
         self,
@@ -870,17 +890,6 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
         # embed positions
-        if attention_mask is None:
-            attention_mask = torch.ones(
-                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
-            )
-            padding_mask = None
-        else:
-            if 0 in attention_mask:
-                padding_mask = attention_mask
-            else:
-                padding_mask = None
-
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-        )
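With this hunk the model-level forward no longer derives a separate padding_mask or builds the 4D mask eagerly; the 2D attention_mask flows down to the layers and is converted there through the shared converter. A small hedged sketch of the remaining model-side handling follows; the function name is hypothetical and the exact default for attention_mask=None in this intermediate commit is an assumption.

# Hedged sketch of the model-side mask handling after this change.
import torch

def forward_mask_handling(attention_mask, batch_size, seq_length_with_past, device):
    # Assumed default: ensure a 2D mask exists, then pass it down unchanged;
    # each attention layer converts it to 4D via the shared converter.
    if attention_mask is None:
        attention_mask = torch.ones(
            (batch_size, seq_length_with_past), dtype=torch.bool, device=device
        )
    return attention_mask  # still 2D at the model level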
@@ -910,7 +919,7 @@ def forward(
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
                         # None for past_key_value
-                        return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)
+                        return module(*inputs, past_key_value, output_attentions)
 
                     return custom_forward
 
|
@@ -925,7 +934,6 @@ def custom_forward(*inputs): | |
past_key_value=past_key_value, | ||
output_attentions=output_attentions, | ||
use_cache=use_cache, | ||
padding_mask=padding_mask, | ||
) | ||
|
||
hidden_states = layer_outputs[0] | ||
|
Review comment: Do we plan to move this to the modelling utils or is this gonna be here for all models?
Reply: Either # Copied from or we move it to a utils file. Both would work for me.