diff --git a/longformer/longformer.py b/longformer/longformer.py
index d6a5bb7..277aab1 100644
--- a/longformer/longformer.py
+++ b/longformer/longformer.py
@@ -200,15 +200,8 @@ def forward(
             # (bsz, seq_len, num_heads, extra attention count + 2*window+1)
             attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
 
         if getattr(self, 'global_tokens', 0) > 0:
-            q_g = self.query_global(hidden_states)
-            k_g = self.key_global(hidden_states)
-            v_g = self.value_global(hidden_states)
-            q_g = q_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
-            k_g = k_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
-            v_g = v_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
-            # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch)
-            selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q_g, k_g[:, :self.global_tokens]))
+            selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q, k[:, :self.global_tokens]))
             attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
 
         attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)  # use fp32 for numerical stability
@@ -230,7 +223,7 @@ def forward(
             attn_probs = attn_probs.narrow(-1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch).contiguous()
 
         if getattr(self, 'global_tokens', 0) > 0:
             selected_attn_probs = attn_probs.narrow(-1, 0, self.global_tokens)
-            selected_v = v_g[:, :self.global_tokens]  # v_g has been only computed for global_tokens
+            selected_v = v[:, :self.global_tokens]
             attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs)).transpose(1, 2)
             attn_probs = attn_probs.narrow(-1, self.global_tokens, attn_probs.size(-1) - self.global_tokens).contiguous()
@@ -286,7 +279,16 @@ def forward(
 
         if getattr(self, 'global_tokens', 0) > 0:
             assert not self.use_global_proj  # TODO: support the use of global projections
-            attn_weights = torch.einsum('blhd,bshd->blhs', (q_g[:, :self.global_tokens], k_g))
+            # hidden_states shape: seqlen x batch x dim
+            q_g = self.query_global(hidden_states[:self.global_tokens])
+            k_g = self.key_global(hidden_states)
+            v_g = self.value_global(hidden_states)
+            q_g = q_g.contiguous().view(self.global_tokens, bsz, self.num_heads, self.head_dim).transpose(0, 1)
+            k_g = k_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
+            v_g = v_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
+
+            # attn_weights: batch x source-tokens x heads x target-tokens
+            attn_weights = torch.einsum('blhd,bshd->blhs', (q_g, k_g))
             attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)  # use fp32 for numerical stability
             attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
             selected_attn = torch.matmul(attn_probs.transpose(1, 2), v_g.transpose(1, 2))
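
For reference only (not part of the patch): a minimal standalone sketch of the shape contract the new global-attention path relies on, where q_g is projected only for the leading global tokens while k_g/v_g cover the full sequence. All tensor sizes below are made up for illustration; only the torch calls and the 'blhd,bshd->blhs' einsum layout mirror the patched code.

```python
# Shape sketch with hypothetical sizes -- not taken from the patch.
import torch

bsz, seq_len, num_heads, head_dim = 2, 16, 4, 8
global_tokens = 3  # assumed number of leading global tokens

# After the .view(...).transpose(0, 1) calls in the patch, tensors are
# (batch, time, heads, head_dim); q_g only covers the global tokens.
q_g = torch.randn(bsz, global_tokens, num_heads, head_dim)
k_g = torch.randn(bsz, seq_len, num_heads, head_dim)
v_g = torch.randn(bsz, seq_len, num_heads, head_dim)

# batch x global-tokens x heads x all-tokens
attn_weights = torch.einsum('blhd,bshd->blhs', q_g, k_g)
assert attn_weights.shape == (bsz, global_tokens, num_heads, seq_len)

attn_probs = torch.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)

# (batch, heads, global-tokens, all-tokens) @ (batch, heads, all-tokens, head_dim)
selected_attn = torch.matmul(attn_probs.transpose(1, 2), v_g.transpose(1, 2))
assert selected_attn.shape == (bsz, num_heads, global_tokens, head_dim)
```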