-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathmodules.py
945 lines (822 loc) · 30.6 KB
/
modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from evo.tensor import symmetrize, apc
from product_key_memory import PKM
from functools import partial
def gelu(x):
    """Apply the exact (erf-based) GELU activation element-wise.

    Computes ``x * Phi(x)``, where ``Phi`` is the standard normal CDF written
    as ``0.5 * (1 + erf(x / sqrt(2)))``.

    For information: OpenAI GPT's gelu uses the tanh approximation instead,
    which gives slightly different results:
        0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x ** 3)))
    """
    normal_cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return x * normal_cdf
class TransformerLayer(nn.Module):
    """Pre-LayerNorm Transformer block: self-attention, then a GELU feed-forward.

    Each sub-layer normalizes its input, transforms it, applies dropout and
    adds the result back onto the un-normalized residual stream.
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_embed_dim: int,
        attention_heads: int,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        attention_type: str = "standard",
        performer_attention_features: int = 256,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_embed_dim = ffn_embed_dim
        self.attention_heads = attention_heads
        self.attention_type = attention_type
        self.attention_dropout = attention_dropout
        self.performer_attention_features = performer_attention_features

        self.dropout = nn.Dropout(dropout)
        self.activation_dropout = nn.Dropout(activation_dropout)

        # Attention sub-layer.
        self.self_attn = self.build_self_attention()
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # Position-wise feed-forward sub-layer.
        self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
        self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def build_self_attention(self):
        """Instantiate the attention module named by ``self.attention_type``.

        Raises:
            ValueError: if the type is neither "standard" nor "performer".
        """
        kind = self.attention_type
        if kind == "performer":
            return PerformerAttention(
                self.embed_dim,
                self.attention_heads,
                num_features=self.performer_attention_features,
                dropout=self.attention_dropout,
            )
        if kind == "standard":
            return MultiheadAttention(
                self.embed_dim,
                self.attention_heads,
                dropout=self.attention_dropout,
            )
        raise ValueError(f"Unrecognized attention type {self.attention_type}")

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        need_head_weights=False,
    ):
        """Run the block on ``x``.

        Returns:
            ``(x, attn)``. ``attn`` is the weight output of ``self.self_attn``,
            requested only when ``need_head_weights`` is true.
        """
        # Attention sub-layer with pre-norm residual.
        normed = self.self_attn_layer_norm(x)
        attn_out, attn = self.self_attn(
            normed,
            key_padding_mask=self_attn_padding_mask,
            need_weights=need_head_weights,
            need_head_weights=need_head_weights,
            attn_mask=self_attn_mask,
        )
        x = x + self.dropout(attn_out)

        # Feed-forward sub-layer with pre-norm residual.
        hidden = gelu(self.fc1(self.final_layer_norm(x)))
        hidden = self.fc2(self.activation_dropout(hidden))
        x = x + self.dropout(hidden)
        return x, attn
class PKMLayer(nn.Module):
    """Pre-LayerNorm Transformer block whose feed-forward sub-layer is a
    product-key memory (``PKM``) instead of a dense MLP.

    NOTE(review): ``ffn_embed_dim``, ``dropout`` and ``activation_dropout`` are
    accepted for signature parity with ``TransformerLayer`` but are not applied
    anywhere in this block.
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_embed_dim: int,
        attention_heads: int,
        pkm_attention_heads: int,
        pkm_dim_head: int,
        num_product_keys: int = 128,
        pkm_topk: int = 32,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        attention_type: str = "standard",
        performer_attention_features: int = 256,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_embed_dim = ffn_embed_dim
        self.attention_heads = attention_heads
        self.pkm_attention_heads = pkm_attention_heads
        self.num_product_keys = num_product_keys
        self.pkm_topk = pkm_topk
        self.pkm_dim_head = pkm_dim_head
        self.attention_type = attention_type
        self.attention_dropout = attention_dropout
        self.performer_attention_features = performer_attention_features
        self.self_attn = self.build_self_attention()
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.pkm = PKM(
            self.embed_dim,
            self.pkm_attention_heads,
            self.num_product_keys,
            self.pkm_topk,
            self.pkm_dim_head,
        )
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def build_self_attention(self):
        """Instantiate the attention module named by ``self.attention_type``.

        Raises:
            ValueError: if the type is neither "standard" nor "performer".
        """
        if self.attention_type == "standard":
            return MultiheadAttention(
                self.embed_dim,
                self.attention_heads,
                dropout=self.attention_dropout,
            )
        elif self.attention_type == "performer":
            return PerformerAttention(
                self.embed_dim,
                self.attention_heads,
                num_features=self.performer_attention_features,
                dropout=self.attention_dropout,
            )
        else:
            raise ValueError(f"Unrecognized attention type {self.attention_type}")

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
        need_head_weights=False,
    ):
        """Run self-attention then the PKM sub-layer, each with a pre-norm residual.

        Returns:
            ``(x, attn)`` where ``attn`` is the attention-weight output of
            ``self.self_attn``. It is averaged over heads unless
            ``need_head_weights`` is true.
        """
        residual = x
        x = self.self_attn_layer_norm(x)
        # Both attention modules take a single input tensor (self-attention).
        # Passing query=/key=/value= keywords raised TypeError.
        x, attn = self.self_attn(
            x,
            key_padding_mask=self_attn_padding_mask,
            need_weights=True,
            need_head_weights=need_head_weights,
            attn_mask=self_attn_mask,
        )
        x = residual + x

        residual = x
        x = self.final_layer_norm(x)
        x = self.pkm(x)
        x = residual + x
        return x, attn
class AxialTransformerLayer(nn.Module):
    """Implements an Axial MSA Transformer block.

    The block applies row self-attention, column self-attention and a
    feed-forward network in sequence. Each is wrapped in a
    ``NormalizedResidualBlock``, which applies LayerNorm to the sub-layer
    input, dropout to its output, and a residual add.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_tokens_per_msa: int = 2 ** 14,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.dropout_prob = dropout

        attention_kwargs = dict(dropout=dropout, max_tokens_per_msa=max_tokens_per_msa)
        row_attn = RowSelfAttention(
            embedding_dim, num_attention_heads, **attention_kwargs
        )
        column_attn = ColumnSelfAttention(
            embedding_dim, num_attention_heads, **attention_kwargs
        )
        ffn = FeedForwardNetwork(
            embedding_dim,
            ffn_embedding_dim,
            activation_dropout=activation_dropout,
            max_tokens_per_msa=max_tokens_per_msa,
        )

        self.row_self_attention = self.build_residual(row_attn)
        self.column_self_attention = self.build_residual(column_attn)
        self.feed_forward_layer = self.build_residual(ffn)

    def build_residual(self, layer: nn.Module):
        """Wrap ``layer`` in a pre-norm residual block with this layer's dropout."""
        return NormalizedResidualBlock(layer, self.embedding_dim, self.dropout_prob)

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_head_weights: bool = False,
    ):
        """Apply row attention, column attention and the FFN to ``x``.

        Returns:
            ``x`` alone, or ``(x, column_attn, row_attn)`` when
            ``need_head_weights`` is true.
        """
        masks = dict(
            self_attn_mask=self_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
        )
        x, row_attn = self.row_self_attention(x, **masks)
        x, column_attn = self.column_self_attention(x, **masks)
        x = self.feed_forward_layer(x)
        return (x, column_attn, row_attn) if need_head_weights else x
class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.

    It works in one of two modes:

    * ``padding_idx`` is an int: ``forward`` receives token ids. Positions are
      the 1-based running count of non-padding tokens, offset by
      ``padding_idx``. Padding tokens map to row ``padding_idx``.
    * ``padding_idx`` is None: ``forward`` receives position ids directly and
      looks them up unchanged.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int]):
        if padding_idx is not None:
            # Reserve rows 0..padding_idx so shifted positions never collide
            # with the padding row.
            num_embeddings_ = num_embeddings + padding_idx + 1
        else:
            num_embeddings_ = num_embeddings
        super().__init__(num_embeddings_, embedding_dim, padding_idx)
        self.max_positions = num_embeddings

    def forward(self, input: torch.Tensor):
        """Input is expected to be of size [bsz x seqlen]."""
        if self.padding_idx is None:
            # The caller supplies explicit position ids. Previously this path
            # crashed on ``input.ne(None)``.
            positions = input.long()
        else:
            mask = input.ne(self.padding_idx).int()
            positions = (
                torch.cumsum(mask, dim=1).type_as(mask) * mask
            ).long() + self.padding_idx
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
class RobertaLMHead(nn.Module):
    """Head for masked language modeling.

    Pipeline: dense -> gelu -> LayerNorm -> projection by ``weight`` plus a
    learned bias. ``weight`` is applied with ``F.linear``, so it is expected to
    be ``[output_dim, embed_dim]``. Presumably it is shared with the token
    embedding matrix; confirm at the call site.
    """

    def __init__(self, embed_dim, output_dim, weight):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.weight = weight
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features):
        hidden = self.layer_norm(gelu(self.dense(features)))
        # Project back to the vocabulary size and add the learned bias.
        logits = F.linear(hidden, self.weight)
        return logits + self.bias
class ContactPredictionHead(nn.Module):
    """Predict contact probabilities from stacked attention maps.

    Special-token rows and columns are dropped. The maps are then symmetrized
    and APC-corrected, and a logistic regression is applied over the
    ``layers * heads`` channels.
    """

    def __init__(
        self,
        in_features: int,
        prepend_bos: bool,
        append_eos: bool,
        bias=True,
        eos_idx: Optional[int] = None,
    ):
        super().__init__()
        self.in_features = in_features
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        if eos_idx is None and append_eos:
            raise ValueError(
                "Using an alphabet with eos token, but no eos token was passed in."
            )
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias)
        self.activation = nn.Sigmoid()

    def forward(self, tokens, attentions):
        """Return per-pair contact probabilities.

        ``tokens`` is ``[batch, seqlen]`` and ``attentions`` is
        ``[batch, layers, heads, seqlen, seqlen]``; both shapes follow from the
        indexing below. The result is ``[batch, L, L]``, where ``L`` excludes
        the dropped bos/eos positions.
        """
        if self.append_eos:
            # Zero every pair touching an eos token, then drop the final position.
            not_eos = tokens.ne(self.eos_idx).to(attentions)
            pair_mask = not_eos.unsqueeze(1) * not_eos.unsqueeze(2)
            attentions = attentions * pair_mask[:, None, None, :, :]
            attentions = attentions[..., :-1, :-1]
        if self.prepend_bos:
            # Drop the leading cls/bos position.
            attentions = attentions[..., 1:, 1:]

        batch_size, layers, heads, seqlen, _ = attentions.size()
        features = attentions.view(batch_size, layers * heads, seqlen, seqlen)
        # Attentions may arrive as float32; match the head's parameter dtype/device.
        features = features.to(next(self.parameters()))
        features = apc(symmetrize(features)).permute(0, 2, 3, 1)
        logits = self.regression(features).squeeze(3)
        return self.activation(logits)
class NormalizedResidualBlock(nn.Module):
    """Pre-LayerNorm residual wrapper around an arbitrary sub-layer.

    ``forward`` computes ``x + dropout(layer(layer_norm(x), *args, **kwargs))``.
    If the wrapped layer returns a tuple, only its first element joins the
    residual. The remaining elements are passed through unchanged and the
    result is returned as a tuple.
    """

    def __init__(
        self,
        layer: nn.Module,
        embedding_dim: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.layer = layer
        self.dropout_module = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(self.embedding_dim)

    def forward(self, x, *args, **kwargs):
        result = self.layer(self.layer_norm(x), *args, **kwargs)
        if not isinstance(result, tuple):
            return x + self.dropout_module(result)
        primary, *extras = result
        return (x + self.dropout_module(primary),) + tuple(extras)
class FeedForwardNetwork(nn.Module):
    """Position-wise two-layer MLP: ``fc2(dropout(GELU(fc1(x))))``.

    ``max_tokens_per_msa`` is stored for interface parity but is not used by
    ``forward``.
    """

    def __init__(
        self,
        embedding_dim: int,
        ffn_embedding_dim: int,
        activation_dropout: float = 0.1,
        max_tokens_per_msa: int = 2 ** 14,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.ffn_embedding_dim = ffn_embedding_dim
        self.max_tokens_per_msa = max_tokens_per_msa
        self.activation_fn = nn.GELU()
        self.activation_dropout_module = nn.Dropout(activation_dropout)
        self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)

    def forward(self, x):
        hidden = self.activation_dropout_module(self.activation_fn(self.fc1(x)))
        return self.fc2(hidden)
class MultiheadAttention(nn.Module):
    """Multi-headed self-attention.

    See "Attention Is All You Need" for more details. When
    ``F.multi_head_attention_forward`` is available, TorchScript is not
    scripting, and per-head weights are not requested, ``forward`` delegates to
    that fused function. Otherwise it runs an explicit batched-matmul
    implementation.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5
        # Read by the explicit path: when True, attn_mask is materialised over
        # the batch*heads dimension instead of relying on broadcasting. It must
        # be initialised here, or that path raises AttributeError.
        self.onnx_trace = False

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.reset_parameters()

        self.enable_torch_version = hasattr(F, "multi_head_attention_forward")
        if self.enable_torch_version:
            self._attn_fn = partial(
                F.multi_head_attention_forward,  # type: ignore
                embed_dim_to_check=self.embed_dim,
                num_heads=self.num_heads,
                in_proj_weight=torch.empty([0]),
                bias_k=None,
                bias_v=None,
                add_zero_attn=False,
                dropout_p=self.dropout,
                use_separate_proj_weight=True,
            )

    def attn_fn(
        self,
        query,
        key,
        value,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = False,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        """Run ``F.multi_head_attention_forward`` with this module's projections."""
        # With bias=False the projections have no bias tensors; the fused
        # function accepts None, whereas torch.cat over None raised TypeError.
        if self.q_proj.bias is None:
            in_proj_bias = None
        else:
            in_proj_bias = torch.cat(
                (self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)
            )
        return self._attn_fn(
            query,
            key,
            value,
            in_proj_bias=in_proj_bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            out_proj_weight=self.out_proj.weight,
            out_proj_bias=self.out_proj.bias,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
        )

    def reset_parameters(self):
        # Empirically observed the convergence to be much better with
        # the scaled initialization
        nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(
        self,
        x: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = False,
        attn_mask: Optional[torch.Tensor] = None,
        need_head_weights: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Self-attention over ``x``. Input shape: Time x Batch x Channel.

        Args:
            x: input of shape ``(tgt_len, batch, embed_dim)``.
            key_padding_mask (ByteTensor/BoolTensor, optional): mask to exclude
                keys that are pads, of shape ``(batch, src_len)``, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (Tensor, optional): typically used to implement causal
                attention. In the explicit path it is added to the logits and
                must broadcast to ``(tgt_len, src_len)``.
            need_head_weights (bool, optional): return the attention weights
                for each head, shape ``(heads, batch, tgt_len, src_len)``.
                Implies *need_weights*.

        Returns:
            ``(attn, attn_weights)``. ``attn`` has the input's shape.
            ``attn_weights`` is None unless weights were requested.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = x.size()
        assert embed_dim == self.embed_dim

        if (
            self.enable_torch_version
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()
            and not need_head_weights
        ):
            return self.attn_fn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
            )

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q *= self.scaling

        # Fold heads into the batch dimension: (bsz * heads, len, head_dim).
        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )
        k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None
        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf")
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # Softmax in float32 for stability, then cast back to the working dtype.
        attn_weights_float = F.softmax(
            attn_weights, dim=-1, dtype=torch.float32  # type: ignore
        )
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)

        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        head_weights: Optional[torch.Tensor] = None
        if need_weights:
            head_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                head_weights = head_weights.mean(dim=0)
        return attn, head_weights
class PerformerAttention(MultiheadAttention):
    """Self-attention computed with ``performer_pytorch.FastAttention``.

    Reuses the q/k/v/out projections set up by ``MultiheadAttention`` and
    replaces only the attention computation. ``attn_mask`` is not supported.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_features: int,
        dropout: float = 0.0,
        bias: bool = True,
    ):
        from performer_pytorch import FastAttention

        super().__init__(embed_dim, num_heads, dropout, bias)
        # Overrides the fused-kernel partial installed by the parent class.
        self._attn_fn = FastAttention(dim_heads=self.head_dim, nb_features=num_features)

    def attn_fn(self, query, key, value):
        return self._attn_fn(query, key, value)

    def forward(
        self,
        x: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = False,
        attn_mask: Optional[torch.Tensor] = None,
        need_head_weights: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Self-attention over ``x`` of shape ``(seqlen, batch, embed_dim)``.

        Returns:
            ``(attn, attn_weights)``. ``attn_weights`` has shape
            ``(heads, batch, seqlen, seqlen)`` when ``need_head_weights`` is
            set, is averaged over heads when only ``need_weights`` is set, and
            is None otherwise.

        Raises:
            NotImplementedError: if ``attn_mask`` is given.
        """
        from einops import rearrange

        seqlen, bsz, _ = x.size()
        q = self.q_proj(x)  # [T x B x D]
        k = self.k_proj(x)  # [T x B x D]
        v = self.v_proj(x)  # [T x B x D]
        q, k, v = map(
            lambda t: rearrange(t, "t b (h d) -> b h t d", h=self.num_heads), (q, k, v)
        )
        if key_padding_mask is not None:
            # Zero the values at padded key positions. masked_fill needs a bool
            # mask (uint8 masks are rejected by recent PyTorch), matching the
            # cast MultiheadAttention applies to key_padding_mask.
            # NOTE(review): padded keys presumably still enter FastAttention's
            # normaliser; verify if exact masking matters.
            mask = key_padding_mask[:, None, :, None].to(torch.bool)
            v = v.masked_fill(mask, 0)
        if attn_mask is not None:
            raise NotImplementedError

        attn = self.attn_fn(q, k, v)
        attn = rearrange(attn, "b h t d -> t b (h d)")
        attn = self.out_proj(attn)

        if need_weights or need_head_weights:
            # Attending over identity "values" recovers the attention matrix itself.
            v_pos = torch.eye(seqlen, dtype=v.dtype, device=v.device)[
                None, None
            ].repeat(bsz, self.num_heads, 1, 1)
            attn_weights = self.attn_fn(q, k, v_pos).transpose(1, 0)
            if not need_head_weights:
                attn_weights = attn_weights.mean(0)
        else:
            attn_weights = None
        return attn, attn_weights
class RowSelfAttention(nn.Module):
    """Compute self-attention over rows of a 2D input, with weights tied across rows.

    Inputs are ``[num_rows, num_cols, batch, embed_dim]``. Query-key products
    are summed over all rows, so every row shares one
    ``[heads, batch, num_cols, num_cols]`` attention map.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        max_tokens_per_msa: int = 2 ** 16,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa
        self.attn_shape = "hnij"

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def align_scaling(self, q):
        """Return the head-dim scaling further divided by sqrt(number of rows of ``q``)."""
        return self.scaling / math.sqrt(q.size(0))

    def _batched_forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Memory-bounded forward that processes the rows in chunks."""
        num_rows, num_cols, batch_size, embed_dim = x.size()
        rows_per_chunk = max(1, self.max_tokens_per_msa // num_cols)
        chunk_starts = range(0, num_rows, rows_per_chunk)
        # Scale by the full row count so summed chunk logits match the unchunked result.
        scaling = self.align_scaling(x)

        # Logits are a sum over rows, so per-chunk partial sums simply add up.
        attns = 0
        for begin in chunk_starts:
            end = begin + rows_per_chunk
            chunk_padding = (
                None
                if self_attn_padding_mask is None
                else self_attn_padding_mask[:, begin:end]
            )
            attns += self.compute_attention_weights(
                x[begin:end],
                scaling,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=chunk_padding,
            )
        attn_probs = self.dropout_module(attns.softmax(-1))

        chunk_outputs = [
            self.compute_attention_update(x[begin : begin + rows_per_chunk], attn_probs)
            for begin in chunk_starts
        ]
        return torch.cat(chunk_outputs, 0), attn_probs

    def compute_attention_weights(
        self,
        x,
        scaling: float,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Return row-summed attention logits of shape ``[heads, batch, cols, cols]``."""
        num_rows, num_cols, batch_size, embed_dim = x.size()
        head_shape = (num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
        q = self.q_proj(x).view(*head_shape)
        k = self.k_proj(x).view(*head_shape)
        q *= scaling
        if self_attn_padding_mask is not None:
            # Zero out any padded aligned positions - this is important since
            # we take a sum across the alignment axis.
            keep = 1 - self_attn_padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q)
            q *= keep

        attn_weights = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)

        if self_attn_mask is not None:
            raise NotImplementedError
        if self_attn_padding_mask is not None:
            # Mask Size: [B x R x C], Weights Size: [H x B x C x C]
            # Keys are masked using the padding of the first row.
            key_mask = self_attn_padding_mask[:, 0].unsqueeze(0).unsqueeze(2)
            attn_weights = attn_weights.masked_fill(key_mask, -10000)
        return attn_weights

    def compute_attention_update(
        self,
        x,
        attn_probs,
    ):
        """Apply the shared attention map to every row's values and project out."""
        num_rows, num_cols, batch_size, embed_dim = x.size()
        v = self.v_proj(x).view(
            num_rows, num_cols, batch_size, self.num_heads, self.head_dim
        )
        context = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v)
        merged = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
        return self.out_proj(merged)

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Return ``(output, attn_probs)``. Large inputs are chunked under no_grad."""
        num_rows, num_cols, batch_size, embed_dim = x.size()
        too_large = num_rows * num_cols > self.max_tokens_per_msa
        if too_large and not torch.is_grad_enabled():
            return self._batched_forward(x, self_attn_mask, self_attn_padding_mask)

        scaling = self.align_scaling(x)
        logits = self.compute_attention_weights(
            x, scaling, self_attn_mask, self_attn_padding_mask
        )
        attn_probs = self.dropout_module(logits.softmax(-1))
        return self.compute_attention_update(x, attn_probs), attn_probs
class ColumnSelfAttention(nn.Module):
    """Compute self-attention over columns of a 2D input.

    Inputs are ``[num_rows, num_cols, batch, embed_dim]``. Within each column,
    every row attends to every other row, giving attention probabilities of
    shape ``[heads, num_cols, batch, num_rows, num_rows]``.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        max_tokens_per_msa: int = 2 ** 16,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        self.max_tokens_per_msa = max_tokens_per_msa

        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout_module = nn.Dropout(dropout)

    def _batched_forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Process the columns in chunks to bound peak memory.

        Columns attend independently, so chunk results can be concatenated.
        Each chunk goes straight to ``compute_attention_update`` rather than
        back through ``forward``. If ``num_rows`` alone exceeds
        ``max_tokens_per_msa``, even a one-column chunk is over budget, and
        re-entering ``forward`` would recurse without end.
        """
        num_rows, num_cols, batch_size, embed_dim = x.size()
        max_cols = max(1, self.max_tokens_per_msa // num_rows)
        outputs = []
        attns = []
        for start in range(0, num_cols, max_cols):
            stop = start + max_cols
            chunk_padding = (
                self_attn_padding_mask[:, :, start:stop]
                if self_attn_padding_mask is not None
                else None
            )
            output, attn = self.compute_attention_update(
                x[:, start:stop],
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=chunk_padding,
            )
            outputs.append(output)
            attns.append(attn)
        return torch.cat(outputs, 1), torch.cat(attns, 1)

    def compute_attention_update(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Return ``(output, attn_probs)`` for the given block of columns."""
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if num_rows == 1:
            # if there is only 1 position, this is equivalent and doesn't break with
            # padding
            attn_probs = torch.ones(
                self.num_heads,
                num_cols,
                batch_size,
                num_rows,
                num_rows,
                device=x.device,
                dtype=x.dtype,
            )
            output = self.out_proj(self.v_proj(x))
        else:
            q = self.q_proj(x).view(
                num_rows, num_cols, batch_size, self.num_heads, self.head_dim
            )
            k = self.k_proj(x).view(
                num_rows, num_cols, batch_size, self.num_heads, self.head_dim
            )
            v = self.v_proj(x).view(
                num_rows, num_cols, batch_size, self.num_heads, self.head_dim
            )
            q *= self.scaling

            attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)

            if self_attn_mask is not None:
                raise NotImplementedError
            if self_attn_padding_mask is not None:
                # Padding mask [B x R x C] -> [1 x C x B x 1 x R] masks padded keys.
                attn_weights = attn_weights.masked_fill(
                    self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3),
                    -10000,
                )

            attn_probs = attn_weights.softmax(-1)
            attn_probs = self.dropout_module(attn_probs)
            context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
            context = context.contiguous().view(
                num_rows, num_cols, batch_size, embed_dim
            )
            output = self.out_proj(context)
        return output, attn_probs

    def forward(
        self,
        x,
        self_attn_mask=None,
        self_attn_padding_mask=None,
    ):
        """Return ``(output, attn_probs)``. Large inputs are chunked under no_grad."""
        num_rows, num_cols, batch_size, embed_dim = x.size()
        if (
            num_rows * num_cols
        ) > self.max_tokens_per_msa and not torch.is_grad_enabled():
            return self._batched_forward(
                x,
                self_attn_mask,
                self_attn_padding_mask,
            )
        else:
            return self.compute_attention_update(
                x, self_attn_mask, self_attn_padding_mask
            )