# roberta2nyströmformer.py
# Forked from LennartKeller/roberta2longformer.
from collections import OrderedDict
from tempfile import TemporaryDirectory
import torch
from transformers import (
    AutoTokenizer,
    NystromformerConfig,
    NystromformerModel,
    RobertaModel,
    RobertaTokenizer,
)


def convert_roberta_to_nystromformer(
    roberta_model, roberta_tokenizer, nystromformer_max_length: int = 50176
):
    # Re-save the tokenizer with the new maximum length so it can be reloaded
    # as the tokenizer of the converted model.
    with TemporaryDirectory() as temp_dir:
        roberta_tokenizer.model_max_length = nystromformer_max_length
        roberta_tokenizer.save_pretrained(temp_dir)
        nystromformer_tokenizer = AutoTokenizer.from_pretrained(temp_dir)

    nystromformer_config = NystromformerConfig.from_dict(roberta_model.config.to_dict())
    # The NystromformerEmbeddings table in transformers holds
    # max_position_embeddings + 2 rows (position ids are offset by 2), so no
    # "- 2" adjustment is needed here; the two extra rows are filled below.
    nystromformer_config.max_position_embeddings = nystromformer_max_length
    nystromformer_model = NystromformerModel(nystromformer_config)

    # Copy the encoder weights.
    nystromformer_model.encoder.load_state_dict(
        roberta_model.encoder.state_dict(), strict=False
    )
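    # Note (added): with strict=False, parameters whose names differ between
    # RobertaEncoder and NystromformerEncoder are simply skipped. Both follow
    # the BERT-style layer naming in transformers, so the attention and
    # feed-forward weights should carry over; load_state_dict's return value
    # lists any missing or unexpected keys if you want to verify the transfer.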

    # ------------ #
    #  Embeddings  #
    # ------------ #
    # There are two kinds of embeddings to handle:
    # 1. Token embeddings
    #    These can simply be copied. The token embeddings have to be resized
    #    upfront so that load_state_dict works.
    nystromformer_model.resize_token_embeddings(len(roberta_tokenizer))
    roberta_embeddings_parameters = roberta_model.embeddings.state_dict()
    embedding_parameters2copy = []
    for key, item in roberta_embeddings_parameters.items():
        # Skip positional and token type embeddings: the positional embeddings
        # are rebuilt below, and the token type embeddings stay freshly
        # initialized (RoBERTa does not make meaningful use of them).
        if "position" not in key and "token_type_embeddings" not in key:
            embedding_parameters2copy.append((key, item))
    # 2. Positional embeddings
    #    RoBERTa's positional embeddings are tiled repeatedly so that they
    #    cover the new, longer maximum sequence length.
    roberta_pos_embs = roberta_model.embeddings.state_dict()[
        "position_embeddings.weight"
    ][2:]
    roberta_pos_embs_extra = roberta_model.embeddings.state_dict()[
        "position_embeddings.weight"
    ][:2]
    assert (
        roberta_pos_embs.size(0) < nystromformer_max_length
    ), "The Nystromformer sequence length has to be longer than RoBERTa's original sequence length"
    # Figure out how many times the original embeddings have to be copied.
    # Integer division keeps the tiled block no longer than the target length;
    # the remainder is filled right below.
    n_copies = nystromformer_max_length // roberta_pos_embs.size(0)
    # Tile the embeddings and append the last missing positions.
    nystromformer_pos_embs = roberta_pos_embs.repeat((n_copies, 1))
    n_pos_embs_left = nystromformer_max_length - nystromformer_pos_embs.size(0)
    nystromformer_pos_embs = torch.cat(
        [nystromformer_pos_embs, roberta_pos_embs[:n_pos_embs_left]], dim=0
    )
    # Re-attach the two extra rows. Unlike RoBERTa, the Nystromformer
    # implementation in transformers (like BigBird's) does not shift
    # position_ids past the padding index; its position ids start at a fixed
    # offset of 2, so the two padding-related rows are prepended and stay unused.
    nystromformer_pos_embs = torch.cat(
        [roberta_pos_embs_extra, nystromformer_pos_embs], dim=0
    )
    embedding_parameters2copy.append(
        ("position_embeddings.weight", nystromformer_pos_embs)
    )
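    # Sanity check (added): the rebuilt position table should contain one row
    # per position plus the two unused leading slots, i.e.
    # nystromformer_max_length + 2 rows, which is the size the transformers
    # NystromformerEmbeddings is expected to allocate.
    assert embedding_parameters2copy[-1][1].size(0) == nystromformer_max_length + 2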

    # Load the collected embedding weights into the Nystromformer model.
    embedding_parameters2copy = OrderedDict(embedding_parameters2copy)
    nystromformer_model.embeddings.load_state_dict(
        embedding_parameters2copy, strict=False
    )
    return nystromformer_model, nystromformer_tokenizer
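

# Minimal usage sketch (added): converts the public "roberta-base" checkpoint
# to a Nystromformer with a 4096-token context. The checkpoint name, target
# length, and output directory are illustrative choices, not part of the
# original script.
if __name__ == "__main__":
    roberta_model = RobertaModel.from_pretrained("roberta-base")
    roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

    nystromformer_model, nystromformer_tokenizer = convert_roberta_to_nystromformer(
        roberta_model, roberta_tokenizer, nystromformer_max_length=4096
    )

    # The position embedding table should now cover the new maximum length
    # (plus the two unused leading rows).
    print(nystromformer_model.embeddings.position_embeddings.weight.shape)

    nystromformer_model.save_pretrained("nystromformer-from-roberta")
    nystromformer_tokenizer.save_pretrained("nystromformer-from-roberta")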