Commit

address pr comments
armancohan committed Nov 5, 2020
1 parent 23b98d4 commit 06187c8
Showing 1 changed file with 12 additions and 10 deletions.
longformer/longformer.py (22 changes: 12 additions & 10 deletions)
@@ -200,15 +200,8 @@ def forward(
  # (bsz, seq_len, num_heads, extra attention count + 2*window+1)
  attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)
  if getattr(self, 'global_tokens', 0) > 0:
-     q_g = self.query_global(hidden_states)
-     k_g = self.key_global(hidden_states)
-     v_g = self.value_global(hidden_states)
-     q_g = q_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
-     k_g = k_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
-     v_g = v_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)

-     # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch)
-     selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q_g, k_g[:, :self.global_tokens]))
+     selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q, k[:, :self.global_tokens]))
      attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1)

  attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
@@ -230,7 +223,7 @@ def forward(
  attn_probs = attn_probs.narrow(-1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch).contiguous()
  if getattr(self, 'global_tokens', 0) > 0:
      selected_attn_probs = attn_probs.narrow(-1, 0, self.global_tokens)
-     selected_v = v_g[:, :self.global_tokens] # v_g has been only computed for global_tokens
+     selected_v = v[:, :self.global_tokens]
      attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs)).transpose(1, 2)
      attn_probs = attn_probs.narrow(-1, self.global_tokens, attn_probs.size(-1) - self.global_tokens).contiguous()

@@ -286,7 +279,16 @@ def forward(

  if getattr(self, 'global_tokens', 0) > 0:
      assert not self.use_global_proj # TODO: support the use of global projections
-     attn_weights = torch.einsum('blhd,bshd->blhs', (q_g[:, :self.global_tokens], k_g))
+     # hidden_states shape: seqlen x batch x dim
+     q_g = self.query_global(hidden_states[:self.global_tokens])
+     k_g = self.key_global(hidden_states)
+     v_g = self.value_global(hidden_states)
+     q_g = q_g.contiguous().view(self.global_tokens, bsz, self.num_heads, self.head_dim).transpose(0, 1)
+     k_g = k_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
+     v_g = v_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
+
+     # attn_weights: batch x source-tokens x heads x target-tokens
+     attn_weights = torch.einsum('blhd,bshd->blhs', (q_g, k_g))
      attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
      attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
      selected_attn = torch.matmul(attn_probs.transpose(1, 2), v_g.transpose(1, 2))
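
For context on the added global-attention block, the following is a minimal standalone shape check of the 'blhd,bshd->blhs' einsum and the subsequent matmul. It is a sketch with toy dimensions chosen for illustration, not code from the repository; it assumes q_g covers only the global tokens while k_g and v_g cover the full sequence, as in the added lines.

import torch

# Toy dimensions for illustration only (assumed, not from the repo).
bsz, seq_len, num_heads, head_dim, global_tokens = 2, 16, 4, 8, 3

q_g = torch.randn(bsz, global_tokens, num_heads, head_dim)  # queries for the global tokens only
k_g = torch.randn(bsz, seq_len, num_heads, head_dim)        # keys over the full sequence
v_g = torch.randn(bsz, seq_len, num_heads, head_dim)        # values over the full sequence

# batch x global-tokens x heads x full-sequence, matching the shape comment in the diff
attn_weights = torch.einsum('blhd,bshd->blhs', (q_g, k_g))
assert attn_weights.shape == (bsz, global_tokens, num_heads, seq_len)

attn_probs = torch.softmax(attn_weights, dim=-1)
# (bsz, heads, global-tokens, seq_len) @ (bsz, heads, seq_len, head_dim)
selected_attn = torch.matmul(attn_probs.transpose(1, 2), v_g.transpose(1, 2))
assert selected_attn.shape == (bsz, num_heads, global_tokens, head_dim)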
