T5 encoder decoder #249

Open · wants to merge 122 commits into base: master

Commits (122)
1f55379
adding output_attentions arg
ibeltagy Jul 4, 2020
b98d191
adding gradient_checkpointing config
ibeltagy Jul 4, 2020
c10277c
convert bart to longformer_encoder_decoder + memory profiler
ibeltagy Jul 4, 2020
e29d7f5
reqs and init
ibeltagy Jul 4, 2020
dd0dc0d
fix req
ibeltagy Jul 4, 2020
54a1328
req
ibeltagy Jul 4, 2020
243cfe8
Update README.md
ibeltagy Jul 6, 2020
fbbc770
Update README.md
ibeltagy Jul 6, 2020
95296ad
pretraining script
ibeltagy Jul 16, 2020
325693e
wip
ibeltagy Jul 16, 2020
985acc9
wip
ibeltagy Jul 17, 2020
023dd78
wip
ibeltagy Jul 17, 2020
08230ac
wip
ibeltagy Jul 17, 2020
fb65d57
.
ibeltagy Jul 17, 2020
0e80cde
pad chunks or start next doc
ibeltagy Jul 17, 2020
6ca7d1b
todo
ibeltagy Jul 17, 2020
a2aa4f7
wip
ibeltagy Jul 17, 2020
62a69d5
wip
ibeltagy Jul 17, 2020
3e3a478
wip
ibeltagy Jul 18, 2020
3bc5354
wip
ibeltagy Jul 18, 2020
1a91024
wip
ibeltagy Jul 18, 2020
5fa21f2
wip
ibeltagy Jul 18, 2020
18eb003
wip
ibeltagy Jul 18, 2020
607e446
wip
ibeltagy Jul 18, 2020
d4659de
wip
ibeltagy Jul 18, 2020
c7c53cb
wip
ibeltagy Jul 19, 2020
0a07daf
wip
ibeltagy Jul 22, 2020
827576c
wip
ibeltagy Jul 22, 2020
5d0c8a2
wip
ibeltagy Jul 22, 2020
413258a
wip
ibeltagy Jul 22, 2020
9b8a7d6
Merge branch 'encoderdecoder' of github.com:allenai/longformer into e…
ibeltagy Jul 22, 2020
1a6498c
tpu
ibeltagy Jul 22, 2020
3e82548
wip
ibeltagy Jul 22, 2020
adadd42
wip
ibeltagy Jul 23, 2020
9e191a0
pretraining script
ibeltagy Jul 16, 2020
9d18808
wip
ibeltagy Jul 16, 2020
6e24cee
wip
ibeltagy Jul 17, 2020
a2ab9b3
wip
ibeltagy Jul 17, 2020
e3f4ba9
wip
ibeltagy Jul 17, 2020
f9e654b
.
ibeltagy Jul 17, 2020
9c2646d
pad chunks or start next doc
ibeltagy Jul 17, 2020
433a2e2
todo
ibeltagy Jul 17, 2020
ec47270
wip
ibeltagy Jul 17, 2020
77e105d
wip
ibeltagy Jul 17, 2020
af08b5a
wip
ibeltagy Jul 18, 2020
d105023
wip
ibeltagy Jul 18, 2020
1183999
wip
ibeltagy Jul 18, 2020
20e8208
wip
ibeltagy Jul 18, 2020
224824d
wip
ibeltagy Jul 18, 2020
4a12730
wip
ibeltagy Jul 18, 2020
c936d24
wip
ibeltagy Jul 18, 2020
510801b
wip
ibeltagy Jul 19, 2020
9184b71
wip
ibeltagy Jul 22, 2020
4ae991a
wip
ibeltagy Jul 22, 2020
aea2a98
tpu
ibeltagy Jul 22, 2020
69b717a
wip
ibeltagy Jul 22, 2020
5f641c0
wip
ibeltagy Jul 23, 2020
e3ddeca
wip
ibeltagy Jul 23, 2020
21c9e57
Merge branch 'trainer' of github.com:allenai/longformer into trainer
ibeltagy Jul 23, 2020
00ce1e9
wip
ibeltagy Jul 23, 2020
56b9c6a
wip
ibeltagy Jul 23, 2020
8fca187
wip
ibeltagy Jul 25, 2020
9dd76b7
wip
ibeltagy Jul 25, 2020
d40983a
wip
ibeltagy Jul 25, 2020
f0f6a30
wip
ibeltagy Jul 25, 2020
a6e37df
Merge branch 'trainer' of github.com:allenai/longformer into trainer
ibeltagy Jul 25, 2020
9eb6fdf
wip
ibeltagy Jul 25, 2020
14b6074
wip
ibeltagy Jul 25, 2020
5b97bd6
wip
ibeltagy Jul 25, 2020
71d7a9d
wip
ibeltagy Jul 25, 2020
97a126d
wip
ibeltagy Jul 25, 2020
c873da2
wip
ibeltagy Jul 25, 2020
d602869
faster gradnorm
ibeltagy Jul 28, 2020
ffd06dd
allow changing seqlen at runtime
ibeltagy Jul 28, 2020
129a3f9
log and resume data preprocessing
ibeltagy Jul 30, 2020
1c42f96
multiprocessed preprocessing
ibeltagy Jul 30, 2020
c20264e
wip
ibeltagy Aug 3, 2020
ff96351
Save this directory as a dataset and use it directly on a plain base …
meslater1030 Aug 3, 2020
0557e24
bug fix
ibeltagy Aug 6, 2020
6ae5051
fix a bug with the mapping from longformerselfattention to bartselfat…
ibeltagy Aug 7, 2020
a1de977
mem_profiler
ibeltagy Aug 7, 2020
1bf6c7c
extend encoder only
ibeltagy Aug 12, 2020
5b31f5e
upgrade triviaqa script to PLv0.8.5
ibeltagy Aug 12, 2020
405739e
add roberta baseline
ibeltagy Aug 13, 2020
162c22f
Merge branch 'mes/longformer-on-beaker-copy-all' into encoderdecoder
ibeltagy Aug 17, 2020
c132d4e
triviaqa seq2seq + fix bart-base bug
ibeltagy Aug 17, 2020
671aa72
Merge branch 'encoderdecoder' of github.com:allenai/longformer into e…
ibeltagy Aug 17, 2020
d1349e9
beaker
ibeltagy Aug 23, 2020
5a2b9da
sliding_chunks_no_overlap (#100)
ibeltagy Aug 25, 2020
b15607b
seq2seq
ibeltagy Aug 26, 2020
82741c3
wip
ibeltagy Aug 27, 2020
5c3a22a
wip
ibeltagy Aug 27, 2020
75aeb47
wip
ibeltagy Aug 28, 2020
391d8de
Update README.md
ibeltagy Sep 1, 2020
3aa2f67
Merge branch 'master' into encoderdecoder
ibeltagy Sep 2, 2020
bf9e58a
summarization
ibeltagy Sep 2, 2020
0158125
fix loading data
ibeltagy Sep 3, 2020
9fdef52
wip
ibeltagy Sep 3, 2020
cbb407d
wip
ibeltagy Sep 3, 2020
eb34cc0
grad_ckpt + reqs + long
ibeltagy Sep 3, 2020
42481fd
ignore empty answers
ibeltagy Sep 4, 2020
79b9b0d
Merge branch 'encoderdecoder' of github.com:allenai/longformer into e…
ibeltagy Sep 4, 2020
c6f2335
attention dropout
ibeltagy Sep 4, 2020
f5b9498
Merge branch 'encoderdecoder' of github.com:allenai/longformer into e…
ibeltagy Sep 4, 2020
f5a798d
model.generate takes a lot of memory. Set requires_grad=False
ibeltagy Sep 4, 2020
274d017
wip
ibeltagy Sep 5, 2020
5f765b9
wip
ibeltagy Sep 5, 2020
5784aee
attention_mode
ibeltagy Sep 6, 2020
b78384a
wip
ibeltagy Sep 7, 2020
48c0344
Merge branch 'encoderdecoder' of github.com:allenai/longformer into e…
ibeltagy Sep 7, 2020
327b729
pegasus bug
ibeltagy Sep 7, 2020
4944fb8
run on cpu
ibeltagy Sep 7, 2020
36252c0
readme
ibeltagy Sep 7, 2020
281999f
readme
ibeltagy Sep 7, 2020
dee3daf
adafactor and label smoothing
ibeltagy Sep 9, 2020
498ca04
add rougeLsum
ibeltagy Sep 13, 2020
d928dfb
Merge branch 'encoderdecoder' of github.com:allenai/longformer into e…
ibeltagy Sep 13, 2020
0f3875f
wip code for LongT5
AkshitaB Nov 12, 2020
adc92ca
naive code for smaller score matrix
AkshitaB Nov 23, 2020
02f9c3f
commenting attn_mask
AkshitaB Nov 23, 2020
de147de
fixing order and issue with inf
AkshitaB Nov 24, 2020
1ba5286
fixing compute_bias
AkshitaB Dec 8, 2020
18 changes: 18 additions & 0 deletions README.md
@@ -1,6 +1,24 @@
# <p align=center>`Longformer`</p>
`Longformer` is a BERT-like model for long documents.


**\*\*\*\*\* Work In Progress: LongformerEncoderDecoder \*\*\*\*\***

A `LongformerEncoderDecoder` model is now available. It is geared towards summarization, where the input is long but the output is relatively short. The following code snippet loads a `LongformerEncoderDecoder` checkpoint that was initialized from `BART`. With gradient checkpointing, fp16, and a 48GB GPU, the input length can be up to 16K tokens.
```
pip install git+https://github.com/allenai/longformer.git@encoderdecoder

# checkpoint-base: https://ai2-s2-research.s3-us-west-2.amazonaws.com/longformer/longformer-encdec-base-16384.tar.gz
# checkpoint-large: https://ai2-s2-research.s3-us-west-2.amazonaws.com/longformer/longformer-encdec-large-16384.tar.gz

from longformer import LongformerEncoderDecoderForConditionalGeneration
model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained(downloaded_checkpoint, gradient_checkpointing=True)
```
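
For orientation, a minimal summarization sketch (not from the repo's README; it assumes the checkpoint is paired with BART's tokenizer, and `downloaded_checkpoint` / `long_document` are placeholders):

```
# hypothetical usage sketch -- the tokenizer choice and variable names are assumptions
import torch
from transformers import BartTokenizer
from longformer import LongformerEncoderDecoderForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')   # assumed compatible tokenizer
model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained(
    downloaded_checkpoint, gradient_checkpointing=True)
model.eval()

input_ids = tokenizer.encode(long_document, return_tensors='pt')   # long input, up to 16K tokens
with torch.no_grad():
    summary_ids = model.generate(input_ids, num_beams=4, max_length=256)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))
```

In practice, `scripts/summarization.py` also handles padding to the attention window and setting the attention mask; prefer it for real runs.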

- Check the script `scripts/summarization.py` for an example of how to use the model.

- Make sure to use the huggingface/transformers fork specified in `requirements.txt`.

**\*\*\*\*\* New July 23rd, 2020: Speed degradation \*\*\*\*\***

A significant speed degradation in huggingface/transformers was recently discovered and fixed (see [this PR](https://github.com/huggingface/transformers/pull/5811) for details). To avoid this problem, either use the old [release v2.11.0](https://github.com/huggingface/transformers/tree/v2.11.0) (which does not support gradient checkpointing) or use the master branch. This problem should be fixed in the next huggingface/transformers release.
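
For example (illustrative commands), either of the following avoids the slow code path:

```
pip install transformers==2.11.0                                  # older release, no gradient checkpointing
pip install git+https://github.com/huggingface/transformers.git   # master branch with the fix
```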
18 changes: 18 additions & 0 deletions experiment.yml
@@ -0,0 +1,18 @@
tasks:
- cluster: {{.Env.CLUSTER}}
spec:
# This is a python3.7/nvidia base image with basic libraries
image: im_j69gti4atcw9
resultPath: {{.Env.RESULT_PATH}}
args:
- /bin/bash
- -c
- "cd /longformer_on_beaker && pip install . && {{.Env.ARGS}}"
datasetMounts:
- datasetId: {{.Env.INPUT_DATASET_ID}}
containerPath: /data
- datasetId: {{.Env.SCRIPTS}}
containerPath: /longformer_on_beaker
requirements:
gpuCount: {{.Env.GPU_COUNT}}
cpu: {{.Env.CPU_COUNT}}
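
The spec above is a template over environment variables. A hypothetical launch sketch (values and the final submit command are placeholders, not documented here):

```
export CLUSTER=...                 # Beaker cluster to run on
export RESULT_PATH=/results
export ARGS="python scripts/summarization.py --help"   # illustrative command to run in the container
export INPUT_DATASET_ID=...        # dataset mounted at /data
export SCRIPTS=...                 # dataset containing this repo, mounted at /longformer_on_beaker
export GPU_COUNT=1
export CPU_COUNT=8
# then render the template and submit it with your Beaker client
```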
3 changes: 3 additions & 0 deletions longformer/__init__.py
@@ -0,0 +1,3 @@
from longformer.longformer import Longformer, LongformerForMaskedLM, LongformerConfig
from longformer.longformer_encoder_decoder import LongformerEncoderDecoderConfig
from longformer.longformer_encoder_decoder import LongformerEncoderDecoderForConditionalGeneration
29 changes: 20 additions & 9 deletions longformer/longformer.py
@@ -5,6 +5,7 @@
import torch.nn.functional as F
from longformer.diagonaled_mm_tvm import diagonaled_mm as diagonaled_mm_tvm, mask_invalid_locations
from longformer.sliding_chunks import sliding_chunks_matmul_qk, sliding_chunks_matmul_pv
from longformer.sliding_chunks import sliding_chunks_no_overlap_matmul_qk, sliding_chunks_no_overlap_matmul_pv
from transformers.modeling_roberta import RobertaConfig, RobertaModel, RobertaForMaskedLM


@@ -48,7 +49,7 @@ def __init__(self, attention_window: List[int] = None, attention_dilation: List[
self.attention_dilation = attention_dilation
self.autoregressive = autoregressive
self.attention_mode = attention_mode
assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2']
assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2', 'sliding_chunks_no_overlap']


class LongformerSelfAttention(nn.Module):
@@ -58,7 +59,6 @@ def __init__(self, config, layer_id):
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.output_attentions = config.output_attentions
self.num_heads = config.num_attention_heads
self.head_dim = int(config.hidden_size / config.num_attention_heads)
self.embed_dim = config.hidden_size
@@ -80,8 +80,8 @@ def __init__(self, config, layer_id):
self.autoregressive = config.autoregressive
assert self.attention_window > 0
assert self.attention_dilation > 0
assert self.attention_mode in ['tvm', 'sliding_chunks']
if self.attention_mode == 'sliding_chunks':
assert self.attention_mode in ['tvm', 'sliding_chunks', 'sliding_chunks_no_overlap']
if self.attention_mode in ['sliding_chunks', 'sliding_chunks_no_overlap']:
assert not self.autoregressive # not supported
assert self.attention_dilation == 1 # dilation is not supported

@@ -147,8 +147,12 @@ def forward(
q = q.float().contiguous()
k = k.float().contiguous()
attn_weights = diagonaled_mm_tvm(q, k, self.attention_window, self.attention_dilation, False, 0, False)
else: # "sliding_chunks"
elif self.attention_mode == "sliding_chunks":
attn_weights = sliding_chunks_matmul_qk(q, k, self.attention_window, padding_value=0)
elif self.attention_mode == "sliding_chunks_no_overlap":
attn_weights = sliding_chunks_no_overlap_matmul_qk(q, k, self.attention_window, padding_value=0)
else:
raise ValueError(f"Unsupported attention mode: {self.attention_mode}")
mask_invalid_locations(attn_weights, self.attention_window, self.attention_dilation, False)
if remove_from_windowed_attention_mask is not None:
# This implementation is fast and takes very little memory because num_heads x hidden_size = 1
@@ -162,10 +166,14 @@ def forward(
# diagonal mask with zeros everywhere and -inf inplace of padding
if self.attention_mode == 'tvm':
d_mask = diagonaled_mm_tvm(ones, float_mask, self.attention_window, self.attention_dilation, False, 0, False)
else:
elif self.attention_mode == "sliding_chunks":
d_mask = sliding_chunks_matmul_qk(ones, float_mask, self.attention_window, padding_value=0)
elif self.attention_mode == "sliding_chunks_no_overlap":
d_mask = sliding_chunks_no_overlap_matmul_qk(ones, float_mask, self.attention_window, padding_value=0)

attn_weights += d_mask
assert list(attn_weights.size()) == [bsz, seq_len, self.num_heads, self.attention_window * 2 + 1]
assert list(attn_weights.size())[:3] == [bsz, seq_len, self.num_heads]
assert attn_weights.size(dim=3) in [self.attention_window * 2 + 1, self.attention_window * 3]

# the extra attention
if extra_attention_mask is not None:
@@ -182,7 +190,6 @@ def forward(
if key_padding_mask is not None:
# softmax sometimes inserts NaN if all positions are masked, replace them with 0
attn_weights_float = torch.masked_fill(attn_weights_float, key_padding_mask.unsqueeze(-1).unsqueeze(-1), 0.0)

attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
v = v.view(seq_len, bsz, self.num_heads, self.head_dim).transpose(0, 1)
@@ -199,8 +206,12 @@ def forward(
if self.attention_mode == 'tvm':
v = v.float().contiguous()
attn += diagonaled_mm_tvm(attn_probs, v, self.attention_window, self.attention_dilation, True, 0, False)
else: # "sliding_chunks"
elif self.attention_mode == "sliding_chunks":
attn += sliding_chunks_matmul_pv(attn_probs, v, self.attention_window)
elif self.attention_mode == "sliding_chunks_no_overlap":
attn += sliding_chunks_no_overlap_matmul_pv(attn_probs, v, self.attention_window)
else:
raise ValueError(f"Unsupported attention mode: {self.attention_mode}")

attn = attn.type_as(hidden_states)
assert list(attn.size()) == [bsz, seq_len, self.num_heads, self.head_dim]
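The diff above adds `sliding_chunks_no_overlap` as a third long-attention mode. A hypothetical configuration sketch (the checkpoint path is a placeholder):

```
# minimal sketch, assuming a locally downloaded Longformer checkpoint directory
from longformer import Longformer, LongformerConfig

config = LongformerConfig.from_pretrained('path/to/longformer-base-4096/')
config.attention_mode = 'sliding_chunks_no_overlap'           # newly allowed by the relaxed asserts
config.attention_window = [256] * config.num_hidden_layers    # per-layer window size
config.attention_dilation = [1] * config.num_hidden_layers    # dilation is not supported in this mode
model = Longformer.from_pretrained('path/to/longformer-base-4096/', config=config)
```

With this mode the per-position score width is `3 * attention_window` instead of `2 * attention_window + 1`, which is exactly what the updated size assertion permits.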
76 changes: 76 additions & 0 deletions longformer/longformer_encoder_decoder.py
@@ -0,0 +1,76 @@
from typing import List, Optional, Tuple, Dict
from torch import nn, Tensor
from longformer.longformer import LongformerSelfAttention
from transformers.modeling_bart import BartConfig, BartForConditionalGeneration


class LongformerEncoderDecoderForConditionalGeneration(BartForConditionalGeneration):
def __init__(self, config):
super().__init__(config)
if config.attention_mode == 'n2':
pass # do nothing, use BertSelfAttention instead
else:
for i, layer in enumerate(self.model.encoder.layers):
layer.self_attn = LongformerSelfAttentionForBart(config, layer_id=i)


class LongformerEncoderDecoderConfig(BartConfig):
def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
gradient_checkpointing: bool = False, **kwargs):
"""
Args:
attention_window: list of attention window sizes of length = number of layers.
window size = number of attention locations on each side.
For an effective window size of 512, use `attention_window=[256]*num_layers`
which is 256 on each side.
attention_dilation: list of attention dilation of length = number of layers.
attention dilation of `1` means no dilation.
autoregressive: do autoregressive attention or have attention on both sides
attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for TVM implementation of Longformer
self-attention, 'sliding_chunks' for another implementation of Longformer self-attention
"""
super().__init__(**kwargs)
self.attention_window = attention_window
self.attention_dilation = attention_dilation
self.autoregressive = autoregressive
self.attention_mode = attention_mode
self.gradient_checkpointing = gradient_checkpointing
assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2']


class LongformerSelfAttentionForBart(nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.embed_dim = config.d_model
self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)
self.output = nn.Linear(self.embed_dim, self.embed_dim)

def forward(
self,
query,
key: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
layer_state: Optional[Dict[str, Optional[Tensor]]] = None,
attn_mask: Optional[Tensor] = None,
need_weights=False,
output_attentions=False,
) -> Tuple[Tensor, Optional[Tensor]]:

tgt_len, bsz, embed_dim = query.size()
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
assert attn_mask is None

outputs = self.longformer_self_attn(
query.transpose(0, 1), # LongformerSelfAttention expects (bsz, seqlen, embd_dim)
attention_mask=key_padding_mask.unsqueeze(dim=1).unsqueeze(dim=1) * -1,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=output_attentions,
)

attn_output = self.output(outputs[0].transpose(0, 1))

return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None)
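
A small shape-convention sketch (illustrative sizes) for the transposes in `LongformerSelfAttentionForBart.forward` above:

```
# illustrative only: BART hands self-attention (tgt_len, bsz, embed_dim) tensors,
# while LongformerSelfAttention expects (bsz, seqlen, embed_dim)
import torch

tgt_len, bsz, embed_dim = 4096, 2, 768
query = torch.randn(tgt_len, bsz, embed_dim)
as_longformer = query.transpose(0, 1)         # what gets passed to longformer_self_attn
assert as_longformer.shape == (bsz, tgt_len, embed_dim)
back_to_bart = as_longformer.transpose(0, 1)  # the output is transposed back before self.output
assert back_to_bart.shape == (tgt_len, bsz, embed_dim)
```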