Merge pull request #1 from armancohan/global-attn
Global attn
armancohan authored Nov 9, 2020
2 parents 3a0b430 + e7e398c commit f379173
Showing 2 changed files with 70 additions and 35 deletions.
22 changes: 14 additions & 8 deletions longformer/longformer.py
@@ -116,10 +116,7 @@ def forward(
if XLA_AVAILABLE:
attention_mask = None # disable global attention and masking for TPUs

if getattr(self, 'global_tokens', 0) > 0: # global tokens at the beginning of the sequence
assert attention_mask is None # no attention_mask is provided (no padding, no global attention selected tokens)

if attention_mask is not None:
if attention_mask is not None and getattr(self, 'global_tokens', 0) == 0:
attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1)
key_padding_mask = attention_mask < 0
extra_attention_mask = attention_mask > 0
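
This change makes the two ways of selecting global attention mutually exclusive: the mask-based path (0 = padding, 1 = local sliding-window attention, 2 = global attention, following the Longformer convention) is only taken when the new global_tokens attribute is unset, while setting global_tokens > 0 marks the first N positions of every sequence as global and expects no attention_mask at all. A rough sketch of the two options (the model is assumed to be a Longformer masked-LM patched as in scripts/test_tpu.py below, so the actual calls are left commented):

import torch

seq_len = 4096
input_ids = torch.ones(1, seq_len, dtype=torch.long)

# Option 1: mask-based selection (0 = padding, 1 = local, 2 = global)
attention_mask = torch.ones(1, seq_len, dtype=torch.long)
attention_mask[:, :10] = 2    # first 10 tokens get global attention
attention_mask[:, -10:] = 0   # last 10 tokens are padding
# output = model(input_ids, attention_mask=attention_mask)

# Option 2: fixed global tokens at the start of the sequence, no mask provided
# for layer in model.roberta.encoder.layer:
#     layer.attention.self.global_tokens = 10
# output = model(input_ids)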
@@ -147,7 +144,6 @@ def forward(
remove_from_windowed_attention_mask = None
extra_attention_mask = None
key_padding_mask = None

hidden_states = hidden_states.transpose(0, 1)
seq_len, bsz, embed_dim = hidden_states.size()
assert embed_dim == self.embed_dim
@@ -283,11 +279,21 @@ def forward(

if getattr(self, 'global_tokens', 0) > 0:
assert not self.use_global_proj # TODO: support the use of global projections
attn_weights = torch.einsum('blhd,bshd->blhs', (q[:, :self.global_tokens], k))
# hidden_states shape: seqlen x batch x dim
selected_q_g = self.query_global(hidden_states[:self.global_tokens])
k_g = self.key_global(hidden_states)
v_g = self.value_global(hidden_states)
selected_q_g = selected_q_g.contiguous().view(self.global_tokens, bsz, self.num_heads, self.head_dim).transpose(0, 1)
k_g = k_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
v_g = v_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)

# attn_weights: batch x source-tokens x heads x target-tokens
attn_weights = torch.einsum('blhd,bshd->blhs', (selected_q_g, k_g))
attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability
attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
selected_attn = torch.matmul(attn_probs.transpose(1, 2), v.transpose(1, 2))
attn[:self.global_tokens] = selected_attn.permute(2, 0, 1, 3).view(self.global_tokens, bsz, -1)
selected_attn = torch.matmul(attn_probs.transpose(1, 2), v_g.transpose(1, 2))
# .contiguous() is needed here: without it, .view raises 'view size is not compatible with input tensor's size and stride'
attn[:self.global_tokens] = selected_attn.permute(2, 0, 1, 3).contiguous().view(self.global_tokens, bsz, -1)

context_layer = attn.transpose(0, 1)
if output_attentions:
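
To make the shape bookkeeping in the new global-attention block easier to follow, here is a minimal self-contained sketch with made-up sizes (the Linear layers stand in for the layer's query_global/key_global/value_global projections; this is illustrative, not the repository code, and dropout is omitted):

import torch
import torch.nn.functional as F

seq_len, bsz, num_heads, head_dim = 32, 2, 4, 16
embed_dim = num_heads * head_dim
global_tokens = 5

query_global = torch.nn.Linear(embed_dim, embed_dim)  # stand-in for self.query_global
key_global = torch.nn.Linear(embed_dim, embed_dim)    # stand-in for self.key_global
value_global = torch.nn.Linear(embed_dim, embed_dim)  # stand-in for self.value_global

hidden_states = torch.randn(seq_len, bsz, embed_dim)  # seqlen x batch x dim
attn = torch.zeros(seq_len, bsz, embed_dim)           # output of the sliding-window attention

# project only the first `global_tokens` positions as queries, all positions as keys/values
selected_q_g = query_global(hidden_states[:global_tokens])
k_g = key_global(hidden_states)
v_g = value_global(hidden_states)

selected_q_g = selected_q_g.view(global_tokens, bsz, num_heads, head_dim).transpose(0, 1)  # b x l x h x d
k_g = k_g.view(seq_len, bsz, num_heads, head_dim).transpose(0, 1)                          # b x s x h x d
v_g = v_g.view(seq_len, bsz, num_heads, head_dim).transpose(0, 1)                          # b x s x h x d

attn_weights = torch.einsum('blhd,bshd->blhs', (selected_q_g, k_g))                # b x l x h x s
attn_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)

selected_attn = torch.matmul(attn_probs.transpose(1, 2), v_g.transpose(1, 2))      # b x h x l x d
# permute -> l x b x h x d; .contiguous() is needed because permute breaks the stride layout required by .view
attn[:global_tokens] = selected_attn.permute(2, 0, 1, 3).contiguous().view(global_tokens, bsz, -1)
print(attn.shape)  # torch.Size([32, 2, 64])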
83 changes: 56 additions & 27 deletions scripts/test_tpu.py
@@ -12,37 +12,53 @@

class CoolDataset(Dataset):

def __init__(self, seq_len, global_tokens, **kwargs):
self.seq_len = seq_len
self.global_tokens = global_tokens
super().__init__(**kwargs)

def __len__(self):
return 100
return 10000

def __getitem__(self, idx):
data = torch.tensor([1, 2, 3, 4] * 128 * 8)
mask = torch.tensor([1, 1, 1, 1] * 128 * 8)
data = torch.tensor([1, 2, 3, 4] * (self.seq_len // 4))
mask = torch.tensor([1, 1, 1, 1] * (self.seq_len // 4))
mask[:10] = 2
mask[-10:] = 0

return data, mask
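
The synthetic batch built here is a repeating 1-2-3-4 token pattern of length seq_len, with an attention mask that marks the first 10 positions as global (value 2) and the last 10 as padding (value 0); the remaining positions keep local sliding-window attention (value 1). Note that the global_tokens argument is stored on the dataset but the mask itself hard-codes 10 global and 10 padding positions. A quick check of the layout (illustrative only):

ds = CoolDataset(seq_len=4096, global_tokens=10)
data, mask = ds[0]
print(data.shape, mask.shape)   # torch.Size([4096]) torch.Size([4096])
print(mask[:12].tolist())       # [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1]
print(mask[-12:].tolist())      # [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]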


class CoolSystem(pl.LightningModule):

def __init__(self):
def __init__(self, args, attention_mode='sliding_chunks', attention_window=256, batch_size=1, seq_len=4096, global_tokens=0):
super().__init__()
from longformer.longformer import LongformerForMaskedLM, LongformerConfig
self.config = LongformerConfig.from_pretrained('allenai/longformer-large-4096')
self.config.attention_mode = 'sliding_chunks'
# self.config.num_hidden_layers = 1
self.config.attention_dilation = [1] * self.config.num_hidden_layers
self.config.attention_window = [256] * self.config.num_hidden_layers
self.model = LongformerForMaskedLM(config=self.config)
for i, layer in enumerate(self.model.roberta.encoder.layer):
layer.attention.self.global_tokens = 0
layer.attention.self.attention_mode = 'sliding_chunks'
# self.model = self.model.roberta.encoder.layer[0].attention.self

# self.model = AutoModel.from_pretrained('allenai/longformer-base-4096')
# self.model = AutoModel.from_pretrained('roberta-base')
# self.config = LongformerConfig.from_pretrained('allenai/longformer-large-4096')
# self.config.attention_mode = attention_mode
# # self.config.num_hidden_layers = 1
# self.config.attention_dilation = [1] * self.config.num_hidden_layers
# self.config.attention_window = [attention_window] * self.config.num_hidden_layers
# self.model = LongformerForMaskedLM(config=self.config)
# for i, layer in enumerate(self.model.roberta.encoder.layer):
# layer.attention.self.global_tokens = 0
# layer.attention.self.attention_mode = attention_mode
# layer.attention.self.attention_window = attention_window
# # self.model = self.model.roberta.encoder.layer[0].attention.self

# # self.model = AutoModel.from_pretrained('allenai/longformer-base-4096')
# # self.model = AutoModel.from_pretrained('roberta-base')
self.batch_size = batch_size
self.count = 0
self.seq_len = seq_len
self.global_tokens = global_tokens

from longformer.longformer import LongformerForMaskedLM, LongformerConfig
self.config = LongformerConfig.from_pretrained(args.model)
self.config.attention_mode = args.attention_mode
self.model = LongformerForMaskedLM.from_pretrained(args.model, config=self.config)
for i, layer in enumerate(self.model.roberta.encoder.layer):
layer.attention.self.global_tokens = global_tokens
layer.attention.self.attention_window = attention_window

def to(self, *args, **kwargs):
param_count_before_moving_to_device = len(list(self.parameters()))
@@ -52,9 +68,9 @@ def to(self, *args, **kwargs):
print('==========', param_count_before_moving_to_device, param_count_after_moving_to_device)

def forward(self, x, y):
print(x.shape, self.model.roberta.encoder.layer[23].attention.self.attention_window,
self.model.roberta.encoder.layer[23].attention.self.global_tokens,
self.model.roberta.encoder.layer[23].attention.self.attention_mode)
print(x.shape, self.model.roberta.encoder.layer[11].attention.self.attention_window,
self.model.roberta.encoder.layer[11].attention.self.global_tokens,
self.model.roberta.encoder.layer[11].attention.self.attention_mode)
return self.model(x, attention_mask=y)
# return self.model(x[:, :, None].expand(1, 4096, 768).float())

@@ -66,20 +82,33 @@ def training_step(self, batch, batch_idx):
y_hat = self(x, y)
loss = y_hat[0].sum()
# xm.mark_step()
# import ipdb; ipdb.set_trace()
# exit()
return {'loss': loss}

def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.001)

def train_dataloader(self):
loader = DataLoader(CoolDataset(), batch_size=1, num_workers=0)
loader = DataLoader(CoolDataset(seq_len=self.seq_len, global_tokens=self.global_tokens), batch_size=self.batch_size, num_workers=0)
return loader


if __name__ == '__main__':
model = CoolSystem()
trainer = pl.Trainer(num_tpu_cores=1, progress_bar_refresh_rate=10, max_epochs=10, num_sanity_val_steps=0, gpus=0,
checkpoint_callback=None)
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--attention_window', type=int, default=256)
parser.add_argument('--attention_mode', default='sliding_chunks')
parser.add_argument('--seq_len', type=int, default=4096)
parser.add_argument('--tpus', type=int, default=None)
parser.add_argument('--gpus', type=int, default=None)
parser.add_argument('--global_tokens', type=int, default=0)
parser.add_argument('--model')
args = parser.parse_args()
model = CoolSystem(args, attention_mode=args.attention_mode, attention_window=args.attention_window, batch_size=args.batch_size, seq_len=args.seq_len, global_tokens=args.global_tokens)
trainer = pl.Trainer(num_tpu_cores=args.tpus, progress_bar_refresh_rate=5, max_epochs=10, num_sanity_val_steps=0,
checkpoint_callback=None, gpus=args.gpus)
trainer.fit(model)

if __name__ == '__main__':
main()
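
With the new command-line interface the benchmark can be launched roughly like this (an illustrative invocation; pass --gpus or --tpus depending on the hardware, with --tpus forwarded to pytorch-lightning's num_tpu_cores):

python scripts/test_tpu.py \
    --model allenai/longformer-base-4096 \
    --attention_mode sliding_chunks \
    --attention_window 256 \
    --seq_len 4096 \
    --batch_size 1 \
    --global_tokens 10 \
    --gpus 1

The debug print in forward() now inspects encoder layer 11 instead of 23, consistent with a 12-layer base checkpoint like the one above rather than the 24-layer large model.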
