-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy path: cosformer.py
251 lines (225 loc) · 9.46 KB
/
cosformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
import torch
import torch.nn.functional as F
import numpy as np
from torch import Tensor
from typing import Optional
from torch import nn
class CosformerAttention(nn.Module):
    """
    cosformer attention in "cosFormer: Rethinking Softmax In Attention"
    https://arxiv.org/abs/2202.08791

    Linear attention with a non-negative feature map (``act_fun``) and the
    re-weighting ``cos(pi/2 * (i - j) / M)``, where ``M`` is the larger of the
    query and key lengths. ``forward`` evaluates it as ``q (k^T v)`` (linear in
    sequence length); ``left_product`` evaluates the same quantity as
    ``(q k^T) v`` and is used to check ``forward`` numerically.

    Args:
        embed_dim: model dimension ``E`` (query input size and output size).
        num_heads: number of heads ``h``; must divide ``embed_dim``.
        kdim: input feature size of ``key`` (defaults to ``embed_dim``).
        vdim: input feature size of ``value`` (defaults to ``embed_dim``).
        dropout_rate: stored as ``self.dropout_rate``; not applied anywhere in
            this class.
        causal: if True, query position ``i`` only attends to key positions
            ``j <= i``; query and key must then have the same length.
        has_outproj: if True, apply ``out_proj`` to the merged heads.
        act_fun: feature map, ``"relu"`` or ``"elu"`` (meaning ``1 + elu(x)``).

    Raises:
        ValueError: if ``act_fun`` is not one of the supported names.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout_rate=0.0,
        causal=False,
        has_outproj=True,
        act_fun="relu",
    ):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        # Each input size falls back to embed_dim based on its *own* argument.
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.num_heads = num_heads
        self.has_outproj = has_outproj
        self.act_fun = self.get_act_fun(act_fun)
        # q, k, v projections (registration order kept: k, v, q, out).
        self.k_proj = nn.Linear(self.kdim, embed_dim)
        self.v_proj = nn.Linear(self.vdim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        # Output projection, used only when has_outproj is True.
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        # NOTE(review): stored but never applied by forward/left_product.
        self.dropout_rate = dropout_rate
        self.causal = causal

    def get_index(self, seq_len):
        """Return ``pi/2 * i`` for ``i = 1..seq_len`` with shape ``(1, seq_len, 1)``.

        The tensor is wrapped in a non-trainable ``nn.Parameter`` but is not
        registered on the module; callers divide it by ``M`` to get the angles.
        """
        index = np.pi / 2 * torch.arange(1, seq_len + 1).reshape(1, -1, 1)
        return nn.Parameter(index, requires_grad=False)

    @staticmethod
    def _elu_plus_one(x):
        """Feature map ``1 + elu(x)``, which is strictly positive."""
        return 1 + F.elu(x)

    def get_act_fun(self, act_fun):
        """Map an activation name to the feature-map callable.

        Raises:
            ValueError: for names other than ``"relu"`` and ``"elu"``.
        """
        if act_fun == "relu":
            return F.relu
        if act_fun == "elu":
            return self._elu_plus_one
        raise ValueError(f"unsupported act_fun {act_fun!r}; expected 'relu' or 'elu'")

    @staticmethod
    def _cos_reweight(x, angle):
        """Concatenate ``x * sin(a)`` and ``x * cos(a)`` along the last dim.

        For two re-weighted rows, ``x_i' . y_j' = (x_i . y_j) * cos(a_i - a_j)``
        because ``sin a_i sin a_j + cos a_i cos a_j = cos(a_i - a_j)``.
        """
        return torch.cat([x * torch.sin(angle), x * torch.cos(angle)], dim=-1)

    def _reweighted_qkv(self, query, key, value):
        """Project, activate, split heads and cos re-weight the inputs.

        Returns ``(q_, k_, v)`` shaped ``(N*h, L, 2d)``, ``(N*h, S, 2d)`` and
        ``(N*h, S, d)``, with ``d = E // h``.
        """
        tgt_len, bsz, embed_dim = query.size()
        src_len = key.size(0)
        head_dim = embed_dim // self.num_heads

        def split_heads(x):
            # (T, N, E) -> (T, N*h, d) -> (N*h, T, d)
            return x.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)

        # Only q and k go through the feature map; v is used as projected.
        q = split_heads(self.act_fun(self.q_proj(query)))
        k = split_heads(self.act_fun(self.k_proj(key)))
        v = split_heads(self.v_proj(value))

        # Angles a_i = pi/2 * i / M, computed on q's device and dtype.
        m = max(src_len, tgt_len)
        angle = self.get_index(m).to(q) / m
        q_ = self._cos_reweight(q, angle[:, :tgt_len, :])
        k_ = self._cos_reweight(k, angle[:, :src_len, :])
        return q_, k_, v

    def _merge_heads(self, attn_output, tgt_len, bsz):
        """``(N*h, L, d) -> (L, N*h, d) -> (L, N, E)``, then the optional out_proj."""
        attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, -1)
        if self.has_outproj:
            attn_output = self.out_proj(attn_output)
        return attn_output

    def forward(
        self,
        query: Tensor,
        key: Optional[Tensor] = None,
        value: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
        eps: float = 1e-6,
    ):
        """Input shape: Sequence x Batch x Embedding

        Args:
            query (Tensor): `(L, N, E)` where L is the target sequence length, N is the batch size,
                E is the embedding dimension.
            key (Tensor): `(S, N, E)` where S is the source sequence length, N is the batch size,
                E is the embedding dimension. Defaults to ``query``.
            value (Tensor): `(S, N, E)` where S is the source sequence length, N is the batch size,
                E is the embedding dimension. Defaults to ``query`` (not ``key``).
            attn_mask (Optional[Tensor], optional): accepted for interface parity with
                ``left_product`` and ignored here; causality comes from prefix sums.
            eps: lower bound applied to the attention normaliser.

        Returns:
            Tensor of shape `(L, N, E)`.

        Raises:
            ValueError: if ``causal`` is set and ``S != L``.
        """
        if key is None:
            key = query
        if value is None:
            value = query
        tgt_len, bsz = query.size(0), query.size(1)
        src_len = key.size(0)
        if self.causal and src_len != tgt_len:
            raise ValueError(
                f"causal attention requires equal query/key lengths, got {tgt_len} and {src_len}"
            )

        q_, k_, v = self._reweighted_qkv(query, key, value)

        if self.causal:
            # TODO: speed/memory -- this materialises a (N*h, L, 2d, d) tensor.
            # Per-position outer products k_j v_j^T, prefix-summed so that
            # row i only contains keys j <= i.
            kv_cum = torch.cumsum(torch.einsum("nld,nlm->nldm", k_, v), dim=1)
            # (N*h, L, 2d) x (N*h, L, 2d, d) -> (N*h, L, d)
            qkv = torch.einsum("nld,nldm->nlm", q_, kv_cum)
            # Normaliser q_i . sum_{j<=i} k_j -> (N*h, L)
            denom = torch.clamp_min(
                torch.einsum("nld,nld->nl", q_, torch.cumsum(k_, dim=1)), eps
            )
            attn_output = qkv / denom.unsqueeze(-1)
        else:
            # (N*h, S, 2d) x (N*h, S, d) -> (N*h, 2d, d)
            kv_ = torch.einsum("nld,nlm->ndm", k_, v)
            # 1 / (q_i . sum_j k_j) -> (N*h, L)
            z_ = 1 / torch.clamp_min(torch.einsum("nld,nd->nl", q_, torch.sum(k_, dim=1)), eps)
            # (N*h, L, 2d) x (N*h, 2d, d) x (N*h, L) -> (N*h, L, d)
            attn_output = torch.einsum("nld,ndm,nl->nlm", q_, kv_, z_)

        return self._merge_heads(attn_output, tgt_len, bsz)

    def left_product(
        self,
        query: Tensor,
        key: Optional[Tensor] = None,
        value: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
        eps: float = 1e-6,
    ):
        """Reference ``(q k^T) v`` evaluation, used to check ``forward``.

        Args:
            query (Tensor): `(L, N, E)` where L is the target sequence length, N is the batch size,
                E is the embedding dimension.
            key (Tensor): `(S, N, E)` where S is the source sequence length, N is the batch size,
                E is the embedding dimension. Defaults to ``query``.
            value (Tensor): `(S, N, E)` where S is the source sequence length, N is the batch size,
                E is the embedding dimension. Defaults to ``query`` (not ``key``).
            attn_mask (Optional[Tensor], optional): only read when ``causal`` is set.
                Entries equal to ``-inf`` zero out the matching ``(L, S)`` weight.
                When None in causal mode, a lower-triangular mask (``j <= i``) is used.
            eps: lower bound applied to each row's normaliser.

        Returns:
            Tensor of shape `(L, N, E)`.
        """
        if key is None:
            key = query
        if value is None:
            value = query
        tgt_len, bsz = query.size(0), query.size(1)
        src_len = key.size(0)

        q_, k_, v = self._reweighted_qkv(query, key, value)
        # (N*h, L, 2d) x (N*h, 2d, S) -> (N*h, L, S)
        weights = torch.bmm(q_, k_.transpose(1, 2))
        if self.causal:
            if attn_mask is None:
                keep = torch.ones(
                    tgt_len, src_len, dtype=torch.bool, device=weights.device
                ).tril()
                weights = weights.masked_fill(~keep, 0)
            else:
                weights = weights.masked_fill(attn_mask.to(weights.device) == float("-inf"), 0)
        # Row-normalise the (non-negative) weights -> (N*h, L, S)
        denom = torch.clamp_min(weights.sum(dim=-1, keepdim=True), eps)
        # (N*h, L, S) x (N*h, S, d) -> (N*h, L, d)
        attn_output = torch.bmm(weights / denom, v)
        return self._merge_heads(attn_output, tgt_len, bsz)
def test(batch=2, tgt_len=10, src_len=20, embed_dim=128, num_heads=8, N=100, causal=False):
    """Compare ``forward`` (right product) with ``left_product`` on random inputs.

    Runs ``N`` trials with fresh uniform random query/key/value tensors. It
    prints the mean Frobenius norm of the difference between the two outputs
    and returns that mean as a Python float.

    Raises:
        ValueError: if ``N < 1``, or if ``causal`` and ``src_len != tgt_len``
            (the causal path needs aligned query/key positions).
    """
    if N < 1:
        raise ValueError(f"N must be at least 1, got {N}")
    if causal and src_len != tgt_len:
        raise ValueError(
            f"causal test requires src_len == tgt_len, got {src_len} and {tgt_len}"
        )
    model = CosformerAttention(embed_dim=embed_dim, num_heads=num_heads, causal=causal)
    if causal:
        # Float mask: 1.0 where key j <= query i, -inf elsewhere;
        # left_product only looks at the -inf entries.
        mask = (torch.triu(torch.ones(tgt_len, tgt_len)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf'))
    else:
        mask = None
    diff = 0.0
    # No gradients are needed. Without no_grad, accumulating the norm tensors
    # would keep every trial's autograd graph alive until the end.
    with torch.no_grad():
        for _ in range(N):
            query = torch.rand(tgt_len, batch, embed_dim)
            key = torch.rand(src_len, batch, embed_dim)
            value = torch.rand(src_len, batch, embed_dim)
            left_res = model.left_product(query, key, value, mask)
            right_res = model(query, key, value)
            diff += torch.norm(left_res - right_res).item()
    diff /= N
    if causal:
        print("Test result for causal model:")
    else:
        print("Test result for bidirectional model:")
    print(f"The error of left multiplication and right multiplication is {diff}")
    return diff
def main():
    """Run the right-vs-left product check for both attention modes."""
    # (tgt_len, src_len, causal); the causal run uses equal lengths.
    settings = ((10, 20, False), (10, 10, True))
    for tgt, src, is_causal in settings:
        test(tgt_len=tgt, src_len=src, causal=is_causal)


if __name__ == "__main__":
    main()