From 5557b1234fefe680f6a72405bdd96b6da51cc8a3 Mon Sep 17 00:00:00 2001 From: Arman Cohan Date: Wed, 4 Nov 2020 15:41:02 -0800 Subject: [PATCH 1/4] global projections --- longformer/longformer.py | 26 ++++++++----- scripts/test_tpu.py | 83 +++++++++++++++++++++++++++------------- 2 files changed, 73 insertions(+), 36 deletions(-) diff --git a/longformer/longformer.py b/longformer/longformer.py index 00dd805..f9d1899 100644 --- a/longformer/longformer.py +++ b/longformer/longformer.py @@ -116,10 +116,11 @@ def forward( if XLA_AVAILABLE: attention_mask = None # disable global attention and masking for TPUs - if getattr(self, 'global_tokens', 0) > 0: # global tokens at the beginning of the sequence - assert attention_mask is None # no attention_mask is provided (no padding, no global attention selected tokens) + # if getattr(self, 'global_tokens', 0) > 0: # global tokens at the beginning of the sequence + # import ipdb; ipdb.set_trace() + # assert attention_mask is None # no attention_mask is provided (no padding, no global attention selected tokens) - if attention_mask is not None: + if attention_mask is not None and getattr(self, 'global_tokens', 0) == 0: attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1) key_padding_mask = attention_mask < 0 extra_attention_mask = attention_mask > 0 @@ -147,7 +148,6 @@ def forward( remove_from_windowed_attention_mask = None extra_attention_mask = None key_padding_mask = None - hidden_states = hidden_states.transpose(0, 1) seq_len, bsz, embed_dim = hidden_states.size() assert embed_dim == self.embed_dim @@ -204,8 +204,15 @@ def forward( # (bsz, seq_len, num_heads, extra attention count + 2*window+1) attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1) if getattr(self, 'global_tokens', 0) > 0: + q_g = self.query_global(hidden_states) + k_g = self.key_global(hidden_states) + v_g = self.value_global(hidden_states) + q_g = q_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) + k_g = k_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) + v_g = v_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) + # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch) - selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q, k[:, :self.global_tokens])) + selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q_g, k_g[:, :self.global_tokens])) attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1) attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability @@ -227,7 +234,7 @@ def forward( attn_probs = attn_probs.narrow(-1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch).contiguous() if getattr(self, 'global_tokens', 0) > 0: selected_attn_probs = attn_probs.narrow(-1, 0, self.global_tokens) - selected_v = v[:, :self.global_tokens] + selected_v = v_g[:, :self.global_tokens] # v_g has been only computed for global_tokens attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs)).transpose(1, 2) attn_probs = attn_probs.narrow(-1, self.global_tokens, attn_probs.size(-1) - self.global_tokens).contiguous() @@ -283,11 +290,12 @@ def forward( if getattr(self, 'global_tokens', 0) > 0: assert not self.use_global_proj # TODO: support the use of global projections - attn_weights = torch.einsum('blhd,bshd->blhs', (q[:, :self.global_tokens], k)) + attn_weights = torch.einsum('blhd,bshd->blhs', (q_g[:, 
:self.global_tokens], k_g)) attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) - selected_attn = torch.matmul(attn_probs.transpose(1, 2), v.transpose(1, 2)) - attn[:self.global_tokens] = selected_attn.permute(2, 0, 1, 3).view(self.global_tokens, bsz, -1) + selected_attn = torch.matmul(attn_probs.transpose(1, 2), v_g.transpose(1, 2)) + # .view throws error (view size is not compatible with input tensor's size and stride) + attn[:self.global_tokens] = selected_attn.permute(2, 0, 1, 3).reshape(self.global_tokens, bsz, -1) context_layer = attn.transpose(0, 1) if output_attentions: diff --git a/scripts/test_tpu.py b/scripts/test_tpu.py index 73c0a86..735a122 100644 --- a/scripts/test_tpu.py +++ b/scripts/test_tpu.py @@ -12,37 +12,53 @@ class CoolDataset(Dataset): + def __init__(self, seq_len, global_tokens, **kwargs): + self.seq_len = seq_len + self.global_tokens = global_tokens + super().__init__(**kwargs) + def __len__(self): - return 100 + return 10000 def __getitem__(self, idx): - data = torch.tensor([1, 2, 3, 4] * 128 * 8) - mask = torch.tensor([1, 1, 1, 1] * 128 * 8) + data = torch.tensor([1, 2, 3, 4] * (self.seq_len // 4)) + mask = torch.tensor([1, 1, 1, 1] * (self.seq_len // 4)) mask[:10] = 2 mask[-10:] = 0 - return data, mask class CoolSystem(pl.LightningModule): - def __init__(self): + def __init__(self, args, attention_mode='sliding_chunks', attention_window=256, batch_size=1, seq_len=4096, global_tokens=0): super().__init__() - from longformer.longformer import LongformerForMaskedLM, LongformerConfig - self.config = LongformerConfig.from_pretrained('allenai/longformer-large-4096') - self.config.attention_mode = 'sliding_chunks' - # self.config.num_hidden_layers = 1 - self.config.attention_dilation = [1] * self.config.num_hidden_layers - self.config.attention_window = [256] * self.config.num_hidden_layers - self.model = LongformerForMaskedLM(config=self.config) - for i, layer in enumerate(self.model.roberta.encoder.layer): - layer.attention.self.global_tokens = 0 - layer.attention.self.attention_mode = 'sliding_chunks' - # self.model = self.model.roberta.encoder.layer[0].attention.self - # self.model = AutoModel.from_pretrained('allenai/longformer-base-4096') - # self.model = AutoModel.from_pretrained('roberta-base') + # self.config = LongformerConfig.from_pretrained('allenai/longformer-large-4096') + # self.config.attention_mode = attention_mode + # # self.config.num_hidden_layers = 1 + # self.config.attention_dilation = [1] * self.config.num_hidden_layers + # self.config.attention_window = [attention_window] * self.config.num_hidden_layers + # self.model = LongformerForMaskedLM(config=self.config) + # for i, layer in enumerate(self.model.roberta.encoder.layer): + # layer.attention.self.global_tokens = 0 + # layer.attention.self.attention_mode = attention_mode + # layer.attention.self.attention_window = attention_window + # # self.model = self.model.roberta.encoder.layer[0].attention.self + + # # self.model = AutoModel.from_pretrained('allenai/longformer-base-4096') + # # self.model = AutoModel.from_pretrained('roberta-base') + self.batch_size = batch_size self.count = 0 + self.seq_len = seq_len + self.global_tokens = global_tokens + + from longformer.longformer import LongformerForMaskedLM, LongformerConfig + self.config = LongformerConfig.from_pretrained(args.model) + self.config.attention_mode = args.attention_mode + 
self.model = LongformerForMaskedLM.from_pretrained(args.model, config=self.config) + for i, layer in enumerate(self.model.roberta.encoder.layer): + layer.attention.self.global_tokens = global_tokens + layer.attention.self.attention_window = attention_window def to(self, *args, **kwargs): param_count_before_moving_to_device = len(list(self.parameters())) @@ -52,9 +68,9 @@ def to(self, *args, **kwargs): print('==========', param_count_before_moving_to_device, param_count_after_moving_to_device) def forward(self, x, y): - print(x.shape, self.model.roberta.encoder.layer[23].attention.self.attention_window, - self.model.roberta.encoder.layer[23].attention.self.global_tokens, - self.model.roberta.encoder.layer[23].attention.self.attention_mode) + print(x.shape, self.model.roberta.encoder.layer[11].attention.self.attention_window, + self.model.roberta.encoder.layer[11].attention.self.global_tokens, + self.model.roberta.encoder.layer[11].attention.self.attention_mode) return self.model(x, attention_mask=y) # return self.model(x[:, :, None].expand(1, 4096, 768).float()) @@ -66,7 +82,6 @@ def training_step(self, batch, batch_idx): y_hat = self(x, y) loss = y_hat[0].sum() # xm.mark_step() - # import ipdb; ipdb.set_trace() # exit() return {'loss': loss} @@ -74,12 +89,26 @@ def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=0.001) def train_dataloader(self): - loader = DataLoader(CoolDataset(), batch_size=1, num_workers=0) + loader = DataLoader(CoolDataset(seq_len=self.seq_len, global_tokens=self.global_tokens), batch_size=self.batch_size, num_workers=0) return loader -if __name__ == '__main__': - model = CoolSystem() - trainer = pl.Trainer(num_tpu_cores=1, progress_bar_refresh_rate=10, max_epochs=10, num_sanity_val_steps=0, gpus=0, - checkpoint_callback=None) +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--batch_size', type=int, default=1) + parser.add_argument('--attention_window', type=int, default=256) + parser.add_argument('--attention_mode', default='sliding_chunks') + parser.add_argument('--seq_len', type=int, default=4096) + parser.add_argument('--tpus', type=int, default=None) + parser.add_argument('--gpus', type=int, default=None) + parser.add_argument('--global_tokens', type=int, default=0) + parser.add_argument('--model') + args = parser.parse_args() + model = CoolSystem(args, attention_mode=args.attention_mode, attention_window=args.attention_window, batch_size=args.batch_size, seq_len=args.seq_len, global_tokens=args.global_tokens) + trainer = pl.Trainer(num_tpu_cores=args.tpus, progress_bar_refresh_rate=5, max_epochs=10, num_sanity_val_steps=0, + checkpoint_callback=None, gpus=args.gpus) trainer.fit(model) + +if __name__ == '__main__': + main() \ No newline at end of file From 23b98d47c8cee13c47a10d0b23453605c5ec362f Mon Sep 17 00:00:00 2001 From: Arman Cohan Date: Wed, 4 Nov 2020 15:43:54 -0800 Subject: [PATCH 2/4] cleanup --- longformer/longformer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/longformer/longformer.py b/longformer/longformer.py index f9d1899..d6a5bb7 100644 --- a/longformer/longformer.py +++ b/longformer/longformer.py @@ -116,10 +116,6 @@ def forward( if XLA_AVAILABLE: attention_mask = None # disable global attention and masking for TPUs - # if getattr(self, 'global_tokens', 0) > 0: # global tokens at the beginning of the sequence - # import ipdb; ipdb.set_trace() - # assert attention_mask is None # no attention_mask is provided (no padding, no global attention selected tokens) - if 
attention_mask is not None and getattr(self, 'global_tokens', 0) == 0: attention_mask = attention_mask.squeeze(dim=2).squeeze(dim=1) key_padding_mask = attention_mask < 0 From 06187c8e27ad810564e838bba678bba8d168b132 Mon Sep 17 00:00:00 2001 From: Arman Cohan Date: Wed, 4 Nov 2020 20:22:01 -0800 Subject: [PATCH 3/4] address pr comments --- longformer/longformer.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/longformer/longformer.py b/longformer/longformer.py index d6a5bb7..277aab1 100644 --- a/longformer/longformer.py +++ b/longformer/longformer.py @@ -200,15 +200,8 @@ def forward( # (bsz, seq_len, num_heads, extra attention count + 2*window+1) attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1) if getattr(self, 'global_tokens', 0) > 0: - q_g = self.query_global(hidden_states) - k_g = self.key_global(hidden_states) - v_g = self.value_global(hidden_states) - q_g = q_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) - k_g = k_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) - v_g = v_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) - # (bsz, seq_len, num_heads, max_num_extra_indices_per_batch) - selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q_g, k_g[:, :self.global_tokens])) + selected_attn_weights = torch.einsum('blhd,bshd->blhs', (q, k[:, :self.global_tokens])) attn_weights = torch.cat((selected_attn_weights, attn_weights), dim=-1) attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability @@ -230,7 +223,7 @@ def forward( attn_probs = attn_probs.narrow(-1, max_num_extra_indices_per_batch, attn_probs.size(-1) - max_num_extra_indices_per_batch).contiguous() if getattr(self, 'global_tokens', 0) > 0: selected_attn_probs = attn_probs.narrow(-1, 0, self.global_tokens) - selected_v = v_g[:, :self.global_tokens] # v_g has been only computed for global_tokens + selected_v = v[:, :self.global_tokens] attn = torch.matmul(selected_attn_probs.transpose(1, 2), selected_v.transpose(1, 2).type_as(selected_attn_probs)).transpose(1, 2) attn_probs = attn_probs.narrow(-1, self.global_tokens, attn_probs.size(-1) - self.global_tokens).contiguous() @@ -286,7 +279,16 @@ def forward( if getattr(self, 'global_tokens', 0) > 0: assert not self.use_global_proj # TODO: support the use of global projections - attn_weights = torch.einsum('blhd,bshd->blhs', (q_g[:, :self.global_tokens], k_g)) + # hidden_states shape: seqlen x batch x dim + q_g = self.query_global(hidden_states[:self.global_tokens]) + k_g = self.key_global(hidden_states) + v_g = self.value_global(hidden_states) + q_g = q_g.contiguous().view(self.global_tokens, bsz, self.num_heads, self.head_dim).transpose(0, 1) + k_g = k_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) + v_g = v_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) + + # attn_weights: batch x source-tokens x heads x target-tokens + attn_weights = torch.einsum('blhd,bshd->blhs', (q_g, k_g)) attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) selected_attn = torch.matmul(attn_probs.transpose(1, 2), v_g.transpose(1, 2)) From e7e398c80b16cc44fbe031e1401f7bba49ea61bc Mon Sep 17 00:00:00 2001 From: Arman Cohan Date: Thu, 5 Nov 2020 10:11:46 -0800 Subject: 
[PATCH 4/4] pr comments --- longformer/longformer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/longformer/longformer.py b/longformer/longformer.py index 277aab1..53bb0ec 100644 --- a/longformer/longformer.py +++ b/longformer/longformer.py @@ -280,20 +280,20 @@ def forward( if getattr(self, 'global_tokens', 0) > 0: assert not self.use_global_proj # TODO: support the use of global projections # hidden_states shape: seqlen x batch x dim - q_g = self.query_global(hidden_states[:self.global_tokens]) + selected_q_g = self.query_global(hidden_states[:self.global_tokens]) k_g = self.key_global(hidden_states) v_g = self.value_global(hidden_states) - q_g = q_g.contiguous().view(self.global_tokens, bsz, self.num_heads, self.head_dim).transpose(0, 1) + selected_q_g = selected_q_g.contiguous().view(self.global_tokens, bsz, self.num_heads, self.head_dim).transpose(0, 1) k_g = k_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) v_g = v_g.contiguous().view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1) # attn_weights: batch x source-tokens x heads x target-tokens - attn_weights = torch.einsum('blhd,bshd->blhs', (q_g, k_g)) + attn_weights = torch.einsum('blhd,bshd->blhs', (selected_q_g, k_g)) attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32) # use fp32 for numerical stability attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) selected_attn = torch.matmul(attn_probs.transpose(1, 2), v_g.transpose(1, 2)) # .view throws error (view size is not compatible with input tensor's size and stride) - attn[:self.global_tokens] = selected_attn.permute(2, 0, 1, 3).reshape(self.global_tokens, bsz, -1) + attn[:self.global_tokens] = selected_attn.permute(2, 0, 1, 3).contiguous().view(self.global_tokens, bsz, -1) context_layer = attn.transpose(0, 1) if output_attentions:
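
A standalone sketch of the global-projection attention path, as it stands after PATCH 4/4, may help when reviewing the diffs above. This is an illustration, not part of the patch series: the tensor sizes (seq_len=4096, batch=1, 12 heads, 10 global tokens), the random inputs, the explicit 1/sqrt(head_dim) query scaling, and the `attn` placeholder for the sliding-window attention output are assumptions made for the example; only the projection / einsum / matmul structure mirrors the final code in longformer/longformer.py.

import math
import torch
import torch.nn.functional as F
from torch import nn

# Illustrative dimensions (assumed, not taken from the diffs).
seq_len, bsz, embed_dim = 4096, 1, 768
num_heads, global_tokens = 12, 10
head_dim = embed_dim // num_heads

# Separate projections for the global-attention path, analogous to
# query_global / key_global / value_global in the attention module.
query_global = nn.Linear(embed_dim, embed_dim)
key_global = nn.Linear(embed_dim, embed_dim)
value_global = nn.Linear(embed_dim, embed_dim)

hidden_states = torch.randn(seq_len, bsz, embed_dim)  # (seq_len, batch, dim), as noted in the patch
attn = torch.randn(seq_len, bsz, embed_dim)           # stand-in for the sliding-window attention output

# Only the first `global_tokens` positions go through the global query projection;
# the full sequence goes through the global key/value projections.
selected_q_g = query_global(hidden_states[:global_tokens])
k_g = key_global(hidden_states)
v_g = value_global(hidden_states)

selected_q_g = selected_q_g.view(global_tokens, bsz, num_heads, head_dim).transpose(0, 1)  # (B, G, H, hd)
k_g = k_g.view(seq_len, bsz, num_heads, head_dim).transpose(0, 1)                          # (B, S, H, hd)
v_g = v_g.view(seq_len, bsz, num_heads, head_dim).transpose(0, 1)                          # (B, S, H, hd)

# Full attention from every global token to every position in the sequence.
# The 1/sqrt(head_dim) scaling is standard dot-product-attention scaling, added here
# for the example only; the diffs above do not show an explicit scaling for the global queries.
attn_weights = torch.einsum('blhd,bshd->blhs', (selected_q_g / math.sqrt(head_dim), k_g))  # (B, G, H, S)
attn_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)

# (B, H, G, S) x (B, H, S, hd) -> (B, H, G, hd), then back to (G, B, H*hd).
# .contiguous() is required before .view after the permute, which is the
# view-vs-reshape issue the patch comments mention.
selected_attn = torch.matmul(attn_probs.transpose(1, 2), v_g.transpose(1, 2))
attn[:global_tokens] = selected_attn.permute(2, 0, 1, 3).contiguous().view(global_tokens, bsz, -1)

print(attn.shape)  # torch.Size([4096, 1, 768])

On the test-script side, the arguments added to scripts/test_tpu.py (--model, --attention_mode, --attention_window, --seq_len, --batch_size, --global_tokens, --tpus, --gpus) mean a run exercising this path would presumably look like: python scripts/test_tpu.py --model allenai/longformer-base-4096 --attention_mode sliding_chunks --global_tokens 10 --tpus 1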