fix use cache (#3)
younesbelkada authored Sep 26, 2023
1 parent 7800457 commit 689f599
Showing 1 changed file with 5 additions and 3 deletions.
src/transformers/models/opt/modeling_opt.py: 8 changes (5 additions & 3 deletions)
@@ -317,7 +317,7 @@ def forward(
         # for the decoder
         is_cross_attention = key_value_states is not None

-        bsz, tgt_len, _ = hidden_states.size()
+        bsz, _, _ = hidden_states.size()

         # get query proj
         query_states = self.q_proj(hidden_states)
@@ -351,13 +351,15 @@ def forward(
             # if encoder bi-directional self-attention `past_key_value` is always `None`
             past_key_value = (key_states, value_states)

+        query_length = query_states.shape[1]
+        tgt_len = key_states.shape[-2]
+
         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
-        query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim)
+        query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim)
         key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
         value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)

-        _, query_length, _, _ = query_states.shape

         attn_dropout = self.dropout if self.training else 0.0

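The change separates the query length from the key/value length. With `use_cache` enabled during generation, `hidden_states` covers only the newly decoded token(s), while `key_states`/`value_states` also contain the cached past, so a single `tgt_len` read from `hidden_states` is the wrong size for reshaping the keys and values. Below is a rough, self-contained sketch of the shape logic the fix relies on; the tensor sizes and variable names are illustrative, not taken from the repository.

# Minimal sketch (not the actual modeling_opt.py code) of the shape mismatch
# this commit fixes; sizes are made up for illustration.
import torch

bsz, num_heads, head_dim = 2, 12, 64
past_len, new_len = 10, 1          # 10 cached tokens, 1 freshly decoded token

# With use_cache=True, the query only covers the new token(s) ...
query_states = torch.randn(bsz, new_len, num_heads * head_dim)
# ... while keys/values also contain the cached past:
# (bsz, num_heads, past_len + new_len, head_dim)
key_states = torch.randn(bsz, num_heads, past_len + new_len, head_dim)

# Old code: one tgt_len read from hidden_states (= 1) was reused everywhere,
# so reshaping the 11-token key/value tensors with it would not work.
# New code: read each length from the tensor it actually describes.
query_length = query_states.shape[1]   # 1
tgt_len = key_states.shape[-2]         # 11

query_states = query_states.view(bsz, query_length, num_heads, head_dim)
key_states = key_states.transpose(1, 2).view(bsz, tgt_len, num_heads, head_dim)
print(query_states.shape)  # torch.Size([2, 1, 12, 64])
print(key_states.shape)    # torch.Size([2, 11, 12, 64])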
