diff --git a/longformer/longformer.py b/longformer/longformer.py
index 277aab1..53bb0ec 100644
--- a/longformer/longformer.py
+++ b/longformer/longformer.py
@@ -280,20 +280,20 @@ def forward(
         if getattr(self, 'global_tokens', 0) > 0:
             assert not self.use_global_proj  # TODO: support the use of global projections
             # hidden_states shape: seqlen x batch x dim
-            q_g = self.query_global(hidden_states[:self.global_tokens])
+            selected_q_g = self.query_global(hidden_states[:self.global_tokens])
             k_g = self.key_global(hidden_states)
             v_g = self.value_global(hidden_states)
-            q_g = q_g.contiguous().view(self.global_tokens, bsz, self.num_heads, self.head_dim).transpose(0, 1)
+            selected_q_g = selected_q_g.contiguous().view(self.global_tokens, bsz, self.num_heads, self.head_dim).transpose(0, 1)
             k_g = k_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
             v_g = v_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
             # attn_weights: batch x source-tokens x heads x target-tokens
-            attn_weights = torch.einsum('blhd,bshd->blhs', (q_g, k_g))
+            attn_weights = torch.einsum('blhd,bshd->blhs', (selected_q_g, k_g))
             attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)  # use fp32 for numerical stability
             attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
             selected_attn = torch.matmul(attn_probs.transpose(1, 2), v_g.transpose(1, 2))
             # .view throws error (view size is not compatible with input tensor's size and stride)
-            attn[:self.global_tokens] = selected_attn.permute(2, 0, 1, 3).reshape(self.global_tokens, bsz, -1)
+            attn[:self.global_tokens] = selected_attn.permute(2, 0, 1, 3).contiguous().view(self.global_tokens, bsz, -1)
 
         context_layer = attn.transpose(0, 1)
 
         if output_attentions:
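
Note on the `.reshape` -> `.contiguous().view` change in the last hunk: below is a minimal standalone sketch of the stride issue the in-code comment refers to. The shapes here are made up for illustration and are not taken from the model config. `permute` returns a non-contiguous view of the same storage, so a subsequent `.view` raises the quoted RuntimeError; calling `.contiguous()` first materializes the copy, and `.reshape` would perform the same copy implicitly when needed, so the two forms are behaviorally equivalent here.

import torch

# Illustrative sizes only (hypothetical, not from the repo's config).
bsz, num_heads, global_tokens, head_dim = 2, 4, 3, 16

# Mirrors selected_attn above: batch x heads x global-tokens x head_dim
selected_attn = torch.randn(bsz, num_heads, global_tokens, head_dim)

# permute only reorders strides; the underlying memory is untouched.
permuted = selected_attn.permute(2, 0, 1, 3)  # global-tokens x batch x heads x head_dim
print(permuted.is_contiguous())  # False

try:
    permuted.view(global_tokens, bsz, -1)  # fails on non-contiguous memory
except RuntimeError as err:
    print(err)  # "view size is not compatible with input tensor's size and stride ..."

# Both of these work; .reshape copies implicitly when a plain view is impossible.
out = permuted.contiguous().view(global_tokens, bsz, -1)
assert torch.equal(out, permuted.reshape(global_tokens, bsz, -1))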