Update Decision transformer attentions
EduardoPach committed Mar 4, 2024
1 parent a895172 commit 74fb9bd
Showing 1 changed file with 4 additions and 3 deletions.
@@ -572,18 +572,19 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
         return hidden_states
 
 
-DECISIONTRANSFORMERGPT2_ATTENTION_CLASSES = {
+DECISION_TRANSFORMER_GPT2_ATTENTION_CLASSES = {
     "eager": DecisionTransformerGPT2Attention,
     "flash_attention_2": DecisionTransformerGPT2FlashAttention2,
 }
 
 
-# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2, DecisionTransformerGPT2_ATTENTION_CLASSES->DECISIONTRANSFORMERGPT2_ATTENTION_CLASSES
+# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->DecisionTransformerGPT2, DecisionTransformerGPT2_ATTENTION_CLASSES->DECISION_TRANSFORMER_GPT2_ATTENTION_CLASSES
 class DecisionTransformerGPT2Block(nn.Module):
     def __init__(self, config, layer_idx=None):
         super().__init__()
         hidden_size = config.hidden_size
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
-        attention_class = DECISIONTRANSFORMERGPT2_ATTENTION_CLASSES[config._attn_implementation]
+        attention_class = DECISION_TRANSFORMER_GPT2_ATTENTION_CLASSES[config._attn_implementation]
+
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.attn = attention_class(config=config, layer_idx=layer_idx)
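For context, the renamed dictionary implements the attention-dispatch pattern used across the library's model files: the block reads config._attn_implementation and looks up which attention class to instantiate, so the "eager" and "flash_attention_2" paths share the same block code. The sketch below is a minimal, self-contained illustration of that pattern only; ToyConfig, EagerAttention, FlashAttention2, and ToyBlock are simplified stand-ins invented for this example, not the classes from the Decision Transformer modeling file.

import torch.nn as nn


class ToyConfig:
    """Simplified stand-in for the model config (illustrative only)."""

    def __init__(self, hidden_size=128, attn_implementation="eager"):
        self.hidden_size = hidden_size
        self._attn_implementation = attn_implementation


class EagerAttention(nn.Module):
    """Placeholder for an eager attention module."""

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.layer_idx = layer_idx
        self.c_attn = nn.Linear(config.hidden_size, 3 * config.hidden_size)


class FlashAttention2(EagerAttention):
    """Placeholder for a Flash Attention 2 variant (same weights, different kernel)."""


# Dispatch table in the spirit of DECISION_TRANSFORMER_GPT2_ATTENTION_CLASSES:
# the string in the config selects which attention class the block builds.
ATTENTION_CLASSES = {
    "eager": EagerAttention,
    "flash_attention_2": FlashAttention2,
}


class ToyBlock(nn.Module):
    def __init__(self, config, layer_idx=None):
        super().__init__()
        attention_class = ATTENTION_CLASSES[config._attn_implementation]
        self.ln_1 = nn.LayerNorm(config.hidden_size)
        self.attn = attention_class(config=config, layer_idx=layer_idx)


block = ToyBlock(ToyConfig(attn_implementation="flash_attention_2"), layer_idx=0)
print(type(block.attn).__name__)  # FlashAttention2

The underscore-separated name also matches the GPT2_ATTENTION_CLASSES convention in modeling_gpt2.py, which the # Copied from comment above the block is checked against.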
