pr comments
armancohan committed Nov 5, 2020
1 parent 06187c8 commit e7e398c
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions longformer/longformer.py
@@ -280,20 +280,20 @@ def forward(
         if getattr(self, 'global_tokens', 0) > 0:
             assert not self.use_global_proj  # TODO: support the use of global projections
             # hidden_states shape: seqlen x batch x dim
-            q_g = self.query_global(hidden_states[:self.global_tokens])
+            selected_q_g = self.query_global(hidden_states[:self.global_tokens])
             k_g = self.key_global(hidden_states)
             v_g = self.value_global(hidden_states)
-            q_g = q_g.contiguous().view(self.global_tokens, bsz, self.num_heads, self.head_dim).transpose(0, 1)
+            selected_q_g = selected_q_g.contiguous().view(self.global_tokens, bsz, self.num_heads, self.head_dim).transpose(0, 1)
             k_g = k_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
             v_g = v_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)

             # attn_weights: batch x source-tokens x heads x target-tokens
-            attn_weights = torch.einsum('blhd,bshd->blhs', (q_g, k_g))
+            attn_weights = torch.einsum('blhd,bshd->blhs', (selected_q_g, k_g))
             attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)  # use fp32 for numerical stability
             attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
             selected_attn = torch.matmul(attn_probs.transpose(1, 2), v_g.transpose(1, 2))
             # .view throws error (view size is not compatible with input tensor's size and stride)
-            attn[:self.global_tokens] = selected_attn.permute(2, 0, 1, 3).reshape(self.global_tokens, bsz, -1)
+            attn[:self.global_tokens] = selected_attn.permute(2, 0, 1, 3).contiguous().view(self.global_tokens, bsz, -1)

         context_layer = attn.transpose(0, 1)
         if output_attentions:
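For context on the change from .reshape(...) to .contiguous().view(...) in the last hunk line, the following is a minimal standalone sketch (not code from this repository) of the global-attention shape flow shown above. The tensor sizes are made up, and plain random tensors stand in for the outputs of query_global, key_global, and value_global. It reproduces the error mentioned in the in-code comment: after permute the tensor is non-contiguous, so a plain .view raises, while .reshape or .contiguous().view both work.

import torch
import torch.nn.functional as F

# Hypothetical sizes, chosen only for illustration.
seq_len, bsz, num_heads, head_dim = 16, 2, 4, 8
global_tokens = 3

# Stand-ins for the projected hidden states (seqlen x batch x heads x head_dim),
# already transposed to batch-first as in the hunk above.
selected_q_g = torch.randn(global_tokens, bsz, num_heads, head_dim).transpose(0, 1)  # b x l x h x d
k_g = torch.randn(seq_len, bsz, num_heads, head_dim).transpose(0, 1)                 # b x s x h x d
v_g = torch.randn(seq_len, bsz, num_heads, head_dim).transpose(0, 1)                 # b x s x h x d

# attn_weights: batch x source-tokens (global) x heads x target-tokens
attn_weights = torch.einsum('blhd,bshd->blhs', (selected_q_g, k_g))
attn_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)

# batched matmul over heads: (b x h x l x s) @ (b x h x s x d) -> b x h x l x d
selected_attn = torch.matmul(attn_probs.transpose(1, 2), v_g.transpose(1, 2))

# permute to l x b x h x d; the result is non-contiguous, so a plain .view fails
permuted = selected_attn.permute(2, 0, 1, 3)
print(permuted.is_contiguous())  # False
try:
    permuted.view(global_tokens, bsz, -1)
except RuntimeError as err:
    print(err)  # view size is not compatible with input tensor's size and stride ...

# Both of these flatten heads into an l x b x (h*d) tensor and give equal results.
out_reshape = permuted.reshape(global_tokens, bsz, -1)
out_view = permuted.contiguous().view(global_tokens, bsz, -1)
print(torch.equal(out_reshape, out_view))  # True

Since .reshape falls back to a copy exactly when the input is not contiguous, the two spellings are equivalent here; the commit only makes the copy explicit via .contiguous().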

0 comments on commit e7e398c
