Skip to content

Commit

Permalink
Addressed comments
Browse files Browse the repository at this point in the history
  • Loading branch information
EduardoPach committed Mar 26, 2024
1 parent 393d17f commit 94c2fe8
Showing 1 changed file with 6 additions and 236 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.cuda.amp import autocast
Expand All @@ -33,18 +32,12 @@
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from .configuration_decision_transformer import DecisionTransformerConfig


if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "edbeeching/decision-transformer-gym-hopper-medium"
Expand All @@ -54,19 +47,6 @@
from ..deprecated._archive_maps import DECISION_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)


# Copied from transformers.models.gpt2.modeling_gpt2.load_tf_weights_in_gpt2
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
"""Load tf checkpoints in a pytorch model"""
Expand Down Expand Up @@ -347,211 +327,6 @@ def forward(
return outputs # a, present, (attentions)


# Copied from transformers.models.gpt2.modeling_gpt2.GPT2FlashAttention2 with GPT2->DecisionTransformerGPT2
class DecisionTransformerGPT2FlashAttention2(DecisionTransformerGPT2Attention):
"""
DecisionTransformerGPT2 flash attention module. This module inherits from `DecisionTransformerGPT2Attention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
bsz, _, _ = hidden_states.size()
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `DecisionTransformerGPT2Attention(..., is_cross_attention=True)`."
)

query = self.q_attn(hidden_states)
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)

if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)

present = None
if use_cache is True:
present = (key, value)

query_length = query.shape[2]
tgt_len = key.shape[2]

# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim)
key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)

attn_dropout = self.attn_dropout.p if self.training else 0.0

# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (LlamaRMSNorm handles it correctly)

if query.dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.c_proj.weight.dtype

logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)

query = query.to(target_dtype)
key = key.to(target_dtype)
value = value.to(target_dtype)

attn_output = self._flash_attention_forward(
query, key, value, attention_mask, query_length, dropout=attn_dropout
)

attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
attn_output = self.c_proj(attn_weights_reshaped)
attn_output = self.resid_dropout(attn_output)

outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights_reshaped,)

return outputs

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1

# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)

cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)

attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)

return attn_output

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)


# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->DecisionTransformerGPT2
class DecisionTransformerGPT2MLP(nn.Module):
def __init__(self, intermediate_size, config):
Expand All @@ -570,26 +345,22 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl
return hidden_states


DECISION_TRANSFORMER_GPT2_ATTENTION_CLASSES = {
"eager": DecisionTransformerGPT2Attention,
"flash_attention_2": DecisionTransformerGPT2FlashAttention2,
}


# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2, DecisionTransformerGPT2_ATTENTION_CLASSES->DECISION_TRANSFORMER_GPT2_ATTENTION_CLASSES
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2
class DecisionTransformerGPT2Block(nn.Module):
# Ignore copy
def __init__(self, config, layer_idx=None):
super().__init__()
hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
attention_class = DECISION_TRANSFORMER_GPT2_ATTENTION_CLASSES[config._attn_implementation]

self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = attention_class(config=config, layer_idx=layer_idx)
self.attn = DecisionTransformerGPT2Attention(config, layer_idx=layer_idx)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

if config.add_cross_attention:
self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx)
self.crossattention = DecisionTransformerGPT2Attention(
config, is_cross_attention=True, layer_idx=layer_idx
)
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

self.mlp = DecisionTransformerGPT2MLP(inner_dim, config)
Expand Down Expand Up @@ -666,7 +437,6 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
load_tf_weights = load_tf_weights_in_gpt2
base_model_prefix = "transformer"
is_parallelizable = True
_supports_flash_attn_2 = False
supports_gradient_checkpointing = True

def __init__(self, *inputs, **kwargs):
Expand Down

0 comments on commit 94c2fe8

Please sign in to comment.