forked from yiqings/HEGnnEnhanceCnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathft_transformer.py
571 lines (490 loc) · 21.5 KB
/
ft_transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
import math
import enum
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
import warnings
# A module specification: either a string (resolved by `_make_nn_module`) or a
# callable that returns an `nn.Module`.
ModuleType = Union[str, Callable[..., nn.Module]]
# Message attached to assertions that are not expected to fail on valid usage.
_INTERNAL_ERROR_MESSAGE = 'Internal error. Please, open an issue.'
def _is_glu_activation(activation: ModuleType):
    """Tell whether ``activation`` denotes a GLU-style activation.

    An activation counts as GLU-style when it is a string ending with
    ``'GLU'`` or when it is one of the `ReGLU` / `GEGLU` classes.

    Args:
        activation: the activation, given as a string or as a callable.

    Returns:
        `True` for GLU-style activations, `False` otherwise.
    """
    if isinstance(activation, str) and activation.endswith('GLU'):
        return True
    return activation in [ReGLU, GEGLU]
def _make_nn_module(module_type: ModuleType, *args) -> nn.Module:
    """Build a module from its specification.

    Args:
        module_type: either a callable, which is invoked with ``args``, or a
            string. The strings ``'ReGLU'`` and ``'GEGLU'`` map to the classes
            of the same name (built without arguments); any other string is
            looked up as an attribute of `torch.nn`.
        *args: positional arguments for the module constructor.

    Returns:
        the constructed module.

    Raises:
        ValueError: if ``module_type`` is a string that is not an attribute of
            `torch.nn`.
    """
    if not isinstance(module_type, str):
        return module_type(*args)
    if module_type == 'ReGLU':
        return ReGLU()
    if module_type == 'GEGLU':
        return GEGLU()
    try:
        module_cls = getattr(nn, module_type)
    except AttributeError as err:
        raise ValueError(
            f'Failed to construct the module {module_type} with the arguments {args}'
        ) from err
    return module_cls(*args)
def _all_or_none(values):
    """Check that the given values are either all `None` or all not `None`.

    The input is traversed exactly once, so one-shot iterators (generators,
    `iter(...)`) are handled the same way as lists.

    Args:
        values: an iterable of arbitrary objects.

    Returns:
        `True` if every value is `None`, or if no value is `None` (this
        includes the empty input); `False` if the input mixes both.
    """
    # Materialize the flags first: two separate passes over `values` would see
    # an already exhausted iterator on the second pass.
    none_flags = [value is None for value in values]
    return all(none_flags) or not any(none_flags)
def reglu(x: Tensor) -> Tensor:
    """Apply the ReGLU activation.

    The last dimension is split into two halves; the first half is multiplied
    by the ReLU of the second half.

    Args:
        x: input tensor whose last dimension has an even size.

    Returns:
        tensor whose last dimension is half the size of the input's.
    """
    assert x.shape[-1] % 2 == 0
    values, gates = torch.chunk(x, 2, dim=-1)
    return values * F.relu(gates)
def geglu(x: Tensor) -> Tensor:
    """Apply the GEGLU activation.

    The last dimension is split into two halves; the first half is multiplied
    by the GELU of the second half.

    Args:
        x: input tensor whose last dimension has an even size.

    Returns:
        tensor whose last dimension is half the size of the input's.
    """
    assert x.shape[-1] % 2 == 0
    values, gates = torch.chunk(x, 2, dim=-1)
    return values * F.gelu(gates)
class ReGLU(nn.Module):
    """Module wrapper around the `reglu` function."""

    def forward(self, x: Tensor) -> Tensor:
        """Return ``reglu(x)``."""
        gated = reglu(x)
        return gated
class GEGLU(nn.Module):
    """Module form of the GEGLU activation described in [shazeer2020glu].

    Examples:
        .. testcode::

            module = GEGLU()
            x = torch.randn(3, 4)
            assert module(x).shape == (3, 2)

    References:
        * [shazeer2020glu] Noam Shazeer, "GLU Variants Improve Transformer", 2020
    """

    def forward(self, x: Tensor) -> Tensor:
        """Return ``geglu(x)``."""
        gated = geglu(x)
        return gated
class _CLSToken(nn.Module):
    """Trainable [CLS]-token for BERT-like inference.

    The [CLS]-based inference is described in [devlin2018bert]. Calling the
    module appends the token **to the end** of each item in the batch.

    Examples:
        .. testcode::

            batch_size = 2
            n_tokens = 3
            d_token = 4
            cls_token = _CLSToken(d_token, 'uniform')
            x = torch.randn(batch_size, n_tokens, d_token)
            x = cls_token(x)
            assert x.shape == (batch_size, n_tokens + 1, d_token)
            assert (x[:, -1, :] == cls_token.expand(len(x))).all()

    References:
        * [devlin2018bert] Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" 2018
    """

    def __init__(self, d_token: int, initialization: str) -> None:
        """
        Args:
            d_token: the size of the token
            initialization: the initialization policy for the token. Must be one
                of :code:`['uniform', 'normal']`. With :code:`s = d ** -0.5`, the
                policies sample from :code:`Uniform(-s, s)` and
                :code:`Normal(0, s)` respectively. The paper
                [gorishniy2021revisiting] used the 'uniform' policy.

        References:
            * [gorishniy2021revisiting] Yury Gorishniy, Ivan Rubachev, Valentin Khrulkov, Artem Babenko "Revisiting Deep Learning Models for Tabular Data", 2021
        """
        super().__init__()
        policy = _TokenInitialization.from_str(initialization)
        # The parameter is allocated uninitialized and filled in by the policy.
        self.weight = nn.Parameter(Tensor(d_token))
        policy.apply(self.weight, d_token)

    def expand(self, *leading_dimensions: int) -> Tensor:
        """Repeat the [CLS]-token along the given leading dimensions.

        A typical use case is building a batch of [CLS]-tokens; see `_CLSToken`
        for an example.

        Note:
            The result is produced with `torch.Tensor.expand` applied to the
            :code:`weight` parameter, so gradients flow back to the parameter.

        Args:
            leading_dimensions: the sizes of the new leading dimensions

        Returns:
            tensor of the shape :code:`(*leading_dimensions, len(self.weight))`;
            the parameter itself if no dimensions are given
        """
        n_leading = len(leading_dimensions)
        if n_leading == 0:
            return self.weight
        # One leading dimension is created by `expand` itself, the remaining
        # ones are inserted as singleton dimensions first.
        singleton_shape = (1,) * (n_leading - 1) + (-1,)
        token = self.weight.view(*singleton_shape)
        return token.expand(*leading_dimensions, -1)

    def forward(self, x: Tensor) -> Tensor:
        """Append the token **to the end** of each item in the batch (see `_CLSToken`)."""
        cls_column = self.expand(len(x), 1)
        return torch.cat([x, cls_column], dim=1)
class _TokenInitialization(enum.Enum):
    """Initialization policies for token-like parameters."""

    UNIFORM = 'uniform'
    NORMAL = 'normal'

    @classmethod
    def from_str(cls, initialization: str) -> '_TokenInitialization':
        """Convert the name of a policy to the corresponding enum member.

        Args:
            initialization: one of the member values (:code:`'uniform'`,
                :code:`'normal'`).

        Raises:
            ValueError: if the name does not match any member.
        """
        try:
            member = cls(initialization)
        except ValueError:
            valid_values = [item.value for item in _TokenInitialization]
            raise ValueError(f'initialization must be one of {valid_values}')
        return member

    def apply(self, x: Tensor, d: int) -> None:
        """Initialize ``x`` in place with the scale :code:`1 / sqrt(d)`.

        Args:
            x: the tensor to fill
            d: the size that determines the scale of the distribution
        """
        scale = 1 / math.sqrt(d)
        if self == _TokenInitialization.NORMAL:
            nn.init.normal_(x, std=scale)
        elif self == _TokenInitialization.UNIFORM:
            # used in the paper "Revisiting Deep Learning Models for Tabular Data";
            # is equivalent to `nn.init.kaiming_uniform_(x, a=math.sqrt(5))` (which is
            # used by torch to initialize nn.Linear.weight, for example)
            nn.init.uniform_(x, a=-scale, b=scale)
class MultiheadAttention(nn.Module):
    """Multihead Attention (self-/cross-) with optional 'linear' attention.

    To learn more about Multihead Attention, see [devlin2018bert]. See the implementation
    of `Transformer` and the examples below to learn how to use the compression technique
    from [wang2020linformer] to speed up the module when the number of tokens is large.

    Examples:
        .. testcode::

            n_objects, n_tokens, d_token = 2, 3, 12
            n_heads = 6
            a = torch.randn(n_objects, n_tokens, d_token)
            b = torch.randn(n_objects, n_tokens * 2, d_token)
            module = MultiheadAttention(
                d_token=d_token, n_heads=n_heads, dropout=0.2, bias=True, initialization='kaiming'
            )

            # self-attention
            x, attention_stats = module(a, a, None, None)
            assert x.shape == a.shape
            assert attention_stats['attention_probs'].shape == (n_objects * n_heads, n_tokens, n_tokens)
            assert attention_stats['attention_logits'].shape == (n_objects * n_heads, n_tokens, n_tokens)

            # cross-attention
            assert module(a, b, None, None)

            # Linformer self-attention with the 'headwise' sharing policy
            k_compression = torch.nn.Linear(n_tokens, n_tokens // 4)
            v_compression = torch.nn.Linear(n_tokens, n_tokens // 4)
            assert module(a, a, k_compression, v_compression)

            # Linformer self-attention with the 'key-value' sharing policy
            kv_compression = torch.nn.Linear(n_tokens, n_tokens // 4)
            assert module(a, a, kv_compression, kv_compression)

    References:
        * [devlin2018bert] Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding" 2018
        * [wang2020linformer] Sinong Wang, Belinda Z. Li, Madian Khabsa, Han Fang, Hao Ma "Linformer: Self-Attention with Linear Complexity", 2020
    """

    def __init__(
        self,
        *,
        d_token: int,
        n_heads: int,
        dropout: float,
        bias: bool,
        initialization: str,
    ) -> None:
        """
        Args:
            d_token: the token size. Must be a multiple of :code:`n_heads`.
            n_heads: the number of heads. If greater than 1, then the module will have
                an additional output layer (the so called "mixing" layer).
            dropout: dropout rate for the attention map. The dropout is applied to
                *probabilities* and does not affect logits. A falsy value disables
                the dropout layer entirely.
            bias: if `True`, then input (and output, if present) layers also have bias.
                `True` is a reasonable default choice.
            initialization: initialization for input projection layers. Must be one of
                :code:`['kaiming', 'xavier']`. `kaiming` is a reasonable default choice.

        Raises:
            AssertionError: if requirements for the inputs are not met.
        """
        super().__init__()
        if n_heads > 1:
            assert d_token % n_heads == 0, 'd_token must be a multiple of n_heads'
        assert initialization in ['kaiming', 'xavier']

        self.W_q = nn.Linear(d_token, d_token, bias)
        self.W_k = nn.Linear(d_token, d_token, bias)
        self.W_v = nn.Linear(d_token, d_token, bias)
        self.W_out = nn.Linear(d_token, d_token, bias) if n_heads > 1 else None
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout) if dropout else None

        for m in [self.W_q, self.W_k, self.W_v]:
            # the "xavier" branch tries to follow torch.nn.MultiheadAttention;
            # the second condition checks if W_v plays the role of W_out; the latter one
            # is initialized with Kaiming in torch
            if initialization == 'xavier' and (
                m is not self.W_v or self.W_out is not None
            ):
                # gain is needed since W_qkv is represented with 3 separate layers (it
                # implies different fan_out)
                nn.init.xavier_uniform_(m.weight, gain=1 / math.sqrt(2))
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        # With bias=False the output layer has no bias parameter (it is None),
        # so there is nothing to zero out.
        if self.W_out is not None and self.W_out.bias is not None:
            nn.init.zeros_(self.W_out.bias)

    def _reshape(self, x: Tensor) -> Tensor:
        """Split the last dimension into heads and fold the heads into the batch.

        Args:
            x: tensor of the shape :code:`(batch_size, n_tokens, d)`

        Returns:
            tensor of the shape :code:`(batch_size * n_heads, n_tokens, d // n_heads)`
        """
        batch_size, n_tokens, d = x.shape
        d_head = d // self.n_heads
        return (
            x.reshape(batch_size, n_tokens, self.n_heads, d_head)
            .transpose(1, 2)
            .reshape(batch_size * self.n_heads, n_tokens, d_head)
        )

    def forward(
        self,
        x_q: Tensor,
        x_kv: Tensor,
        key_compression: Optional[nn.Linear],
        value_compression: Optional[nn.Linear],
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
        """Perform the forward pass.

        Args:
            x_q: query tokens
            x_kv: key-value tokens
            key_compression: Linformer-style compression for keys
            value_compression: Linformer-style compression for values

        Returns:
            (tokens, attention_stats), where `attention_stats` is a dictionary with
            the keys :code:`'attention_logits'` and :code:`'attention_probs'`
            (the probabilities are reported after dropout, if dropout is enabled).

        Raises:
            AssertionError: if exactly one of the compression modules is `None`.
        """
        assert _all_or_none(
            [key_compression, value_compression]
        ), 'If key_compression is (not) None, then value_compression must (not) be None'
        q, k, v = self.W_q(x_q), self.W_k(x_kv), self.W_v(x_kv)
        for tensor in [q, k, v]:
            assert tensor.shape[-1] % self.n_heads == 0, _INTERNAL_ERROR_MESSAGE
        if key_compression is not None:
            # The compression acts along the token dimension, hence the transposes.
            k = key_compression(k.transpose(1, 2)).transpose(1, 2)
            v = value_compression(v.transpose(1, 2)).transpose(1, 2)  # type: ignore

        batch_size = len(q)
        d_head_key = k.shape[-1] // self.n_heads
        d_head_value = v.shape[-1] // self.n_heads
        n_q_tokens = q.shape[1]

        q = self._reshape(q)
        k = self._reshape(k)
        attention_logits = q @ k.transpose(1, 2) / math.sqrt(d_head_key)
        attention_probs = F.softmax(attention_logits, dim=-1)
        if self.dropout is not None:
            attention_probs = self.dropout(attention_probs)
        x = attention_probs @ self._reshape(v)
        # Unfold the heads from the batch dimension and concatenate them back.
        x = (
            x.reshape(batch_size, self.n_heads, n_q_tokens, d_head_value)
            .transpose(1, 2)
            .reshape(batch_size, n_q_tokens, self.n_heads * d_head_value)
        )
        if self.W_out is not None:
            x = self.W_out(x)
        return x, {
            'attention_logits': attention_logits,
            'attention_probs': attention_probs,
        }
class FT_Transformer(nn.Module):
    """Transformer with extra features.

    This module is the backbone of `FTTransformer`.
    """

    # Set an entry to False to silence the corresponding UserWarning in `__init__`.
    WARNINGS = {'first_prenormalization': True, 'prenormalization': True}

    class FFN(nn.Module):
        """The Feed-Forward Network module used in every `Transformer` block."""

        def __init__(
            self,
            *,
            d_token: int,
            d_hidden: int,
            bias_first: bool,
            bias_second: bool,
            dropout: float,
            activation: ModuleType,
        ):
            """
            Args:
                d_token: the input and output size
                d_hidden: the input size of the second linear layer
                bias_first: whether the first linear layer has a bias
                bias_second: whether the second linear layer has a bias
                dropout: the dropout rate applied after the activation
                activation: the activation between the linear layers
            """
            super().__init__()
            # GLU-style activations halve the last dimension, so the first
            # layer must produce twice as many features for them.
            self.linear_first = nn.Linear(
                d_token,
                d_hidden * (2 if _is_glu_activation(activation) else 1),
                bias_first,
            )
            self.activation = _make_nn_module(activation)
            self.dropout = nn.Dropout(dropout)
            self.linear_second = nn.Linear(d_hidden, d_token, bias_second)

        def forward(self, x: Tensor) -> Tensor:
            """Apply linear -> activation -> dropout -> linear."""
            x = self.linear_first(x)
            x = self.activation(x)
            x = self.dropout(x)
            x = self.linear_second(x)
            return x

    class Head(nn.Module):
        """The final module of the `Transformer` that performs BERT-like inference."""

        def __init__(
            self,
            *,
            d_in: int,
            bias: bool,
            activation: ModuleType,
            normalization: ModuleType,
            d_out: int,
        ):
            """
            Args:
                d_in: the input size
                bias: whether the output linear layer has a bias
                activation: the activation applied before the linear layer
                normalization: the normalization (built with :code:`d_in`)
                    applied before the activation
                d_out: the output size
            """
            super().__init__()
            self.normalization = _make_nn_module(normalization, d_in)
            self.activation = _make_nn_module(activation)
            self.linear = nn.Linear(d_in, d_out, bias)

        def forward(self, x: Tensor) -> Tensor:
            """Apply normalization -> activation -> linear to the last token."""
            x = x[:, -1]
            x = self.normalization(x)
            x = self.activation(x)
            x = self.linear(x)
            return x

    def __init__(
        self,
        *,
        d_token: int,
        n_blocks: int,
        attention_n_heads: int,
        attention_dropout: float,
        attention_initialization: str,
        attention_normalization: str,
        ffn_d_hidden: int,
        ffn_dropout: float,
        ffn_activation: str,
        ffn_normalization: str,
        residual_dropout: float,
        prenormalization: bool,
        first_prenormalization: bool,
        last_layer_query_idx: Union[None, List[int], slice],
        n_tokens: Optional[int],
        kv_compression_ratio: Optional[float],
        kv_compression_sharing: Optional[str],
        head_activation: ModuleType,
        head_normalization: ModuleType,
        d_out: int,
        projection: Optional[bool] = False,
    ) -> None:
        """
        Args:
            d_token: the token size
            n_blocks: the number of blocks
            attention_n_heads: the number of heads of `MultiheadAttention`
            attention_dropout: the dropout rate of `MultiheadAttention`
            attention_initialization: the initialization of `MultiheadAttention`
            attention_normalization: the normalization module (built with
                :code:`d_token`) for the attention branch
            ffn_d_hidden: the hidden size of `FFN`
            ffn_dropout: the dropout rate of `FFN`
            ffn_activation: the activation of `FFN`
            ffn_normalization: the normalization module (built with
                :code:`d_token`) for the feed-forward branch
            residual_dropout: the dropout rate applied to the output of each
                branch before it is added to the residual stream
            prenormalization: if `True`, the normalization is applied to the
                input of each branch; otherwise, it is applied after the
                residual sum
            first_prenormalization: if `True`, the attention branch of the first
                block is also normalized. Must be `False` if `prenormalization`
                is `False`.
            last_layer_query_idx: the tokens used as queries in the last block
                (`None` means all tokens)
            n_tokens: the number of tokens (the input size of the key-value
                compression layers)
            kv_compression_ratio: the ratio between the output and input sizes of
                the key-value compression layers
            kv_compression_sharing: one of
                :code:`['headwise', 'key-value', 'layerwise']`. The arguments
                `n_tokens`, `kv_compression_ratio` and `kv_compression_sharing`
                must be either all `None` or all not `None`.
            head_activation: the activation of `Head`
            head_normalization: the normalization of `Head` (replaced with
                :code:`'Identity'` if `prenormalization` is `False`)
            d_out: the output size of `Head`
            projection: if true, `Head` is applied to the output of the last
                block; otherwise, the output is returned as is

        Raises:
            ValueError: if `last_layer_query_idx` is an integer.
            AssertionError: if the other requirements for the arguments are not met.
        """
        super().__init__()
        if isinstance(last_layer_query_idx, int):
            raise ValueError(
                'last_layer_query_idx must be None, list[int] or slice. '
                f'Do you mean last_layer_query_idx=[{last_layer_query_idx}] ?'
            )
        if not prenormalization:
            assert (
                not first_prenormalization
            ), 'If `prenormalization` is False, then `first_prenormalization` must be False'
        assert _all_or_none([n_tokens, kv_compression_ratio, kv_compression_sharing]), (
            'If any of the following arguments is (not) None, then all of them must (not) be None: '
            'n_tokens, kv_compression_ratio, kv_compression_sharing'
        )
        assert kv_compression_sharing in [None, 'headwise', 'key-value', 'layerwise']

        if not prenormalization and self.WARNINGS['prenormalization']:
            warnings.warn(
                'prenormalization is set to False. Are you sure about this? '
                'The training can become less stable. '
                'You can turn off this warning by tweaking the '
                'FT_Transformer.WARNINGS dictionary.',
                UserWarning,
            )
        if (
            prenormalization
            and first_prenormalization
            and self.WARNINGS['first_prenormalization']
        ):
            warnings.warn(
                'first_prenormalization is set to True. Are you sure about this? '
                'For example, the vanilla FTTransformer with '
                'first_prenormalization=True performs SIGNIFICANTLY worse. '
                'You can turn off this warning by tweaking the '
                'FT_Transformer.WARNINGS dictionary.',
                UserWarning,
            )
            # NOTE(review): blocks the constructor for 3 seconds, presumably to
            # make the warning hard to miss -- confirm that this is still wanted.
            time.sleep(3)

        def make_kv_compression():
            assert (
                n_tokens and kv_compression_ratio
            ), _INTERNAL_ERROR_MESSAGE  # for mypy
            # https://github.com/pytorch/fairseq/blob/1bba712622b8ae4efb3eb793a8a40da386fe11d0/examples/linformer/linformer_src/modules/multihead_linear_attention.py#L83
            return nn.Linear(n_tokens, int(n_tokens * kv_compression_ratio), bias=False)

        # With the 'layerwise' policy one compression layer is shared by all
        # blocks and by both keys and values.
        self.shared_kv_compression = (
            make_kv_compression()
            if kv_compression_ratio and kv_compression_sharing == 'layerwise'
            else None
        )

        self.prenormalization = prenormalization
        self.last_layer_query_idx = last_layer_query_idx

        self.blocks = nn.ModuleList([])
        for layer_idx in range(n_blocks):
            layer = nn.ModuleDict(
                {
                    'attention': MultiheadAttention(
                        d_token=d_token,
                        n_heads=attention_n_heads,
                        dropout=attention_dropout,
                        bias=True,
                        initialization=attention_initialization,
                    ),
                    'ffn': FT_Transformer.FFN(
                        d_token=d_token,
                        d_hidden=ffn_d_hidden,
                        bias_first=True,
                        bias_second=True,
                        dropout=ffn_dropout,
                        activation=ffn_activation,
                    ),
                    'attention_residual_dropout': nn.Dropout(residual_dropout),
                    'ffn_residual_dropout': nn.Dropout(residual_dropout),
                    'output': nn.Identity(),  # for hooks-based introspection
                }
            )
            # The attention normalization is skipped only for the first block
            # of a prenormalized model without `first_prenormalization`.
            if layer_idx or not prenormalization or first_prenormalization:
                layer['attention_normalization'] = _make_nn_module(
                    attention_normalization, d_token
                )
            layer['ffn_normalization'] = _make_nn_module(ffn_normalization, d_token)
            if kv_compression_ratio and self.shared_kv_compression is None:
                layer['key_compression'] = make_kv_compression()
                if kv_compression_sharing == 'headwise':
                    layer['value_compression'] = make_kv_compression()
                else:
                    assert (
                        kv_compression_sharing == 'key-value'
                    ), _INTERNAL_ERROR_MESSAGE
            self.blocks.append(layer)

        self.head = (
            FT_Transformer.Head(
                d_in=d_token,
                d_out=d_out,
                bias=True,
                activation=head_activation,  # type: ignore
                normalization=head_normalization if prenormalization else 'Identity',
            )
            if projection
            else nn.Identity()
        )

    def _get_kv_compressions(self, layer):
        """Return the (key, value) compression modules for the given block.

        The shared module takes precedence; otherwise the block's own modules
        are used. A block with only a key compression reuses it for values.
        `(None, None)` is returned when no compression is configured.
        """
        if self.shared_kv_compression is not None:
            return (self.shared_kv_compression, self.shared_kv_compression)
        if 'key_compression' in layer and 'value_compression' in layer:
            return (layer['key_compression'], layer['value_compression'])
        if 'key_compression' in layer:
            return (layer['key_compression'], layer['key_compression'])
        return (None, None)

    def _start_residual(self, layer, stage, x):
        """Prepare the input of a residual branch.

        With prenormalization, the stage's normalization is applied if the
        block has one; otherwise the input is returned unchanged.
        """
        assert stage in ['attention', 'ffn'], _INTERNAL_ERROR_MESSAGE
        x_residual = x
        if self.prenormalization:
            norm_key = f'{stage}_normalization'
            if norm_key in layer:
                x_residual = layer[norm_key](x_residual)
        return x_residual

    def _end_residual(self, layer, stage, x, x_residual):
        """Add the branch output (after dropout) to the residual stream.

        Without prenormalization, the stage's normalization is applied to the sum.
        """
        assert stage in ['attention', 'ffn'], _INTERNAL_ERROR_MESSAGE
        x_residual = layer[f'{stage}_residual_dropout'](x_residual)
        x = x + x_residual
        if not self.prenormalization:
            x = layer[f'{stage}_normalization'](x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """Apply all blocks and then the head.

        Args:
            x: tensor of the shape :code:`(n_objects, n_tokens, d_token)`

        Returns:
            the output of `self.head` applied to the output of the last block

        Raises:
            AssertionError: if the input does not have 3 dimensions.
        """
        assert (
            x.ndim == 3
        ), 'The input must have 3 dimensions: (n_objects, n_tokens, d_token)'
        for layer_idx, layer in enumerate(self.blocks):
            layer = cast(nn.ModuleDict, layer)

            # Only the last block may restrict the set of query tokens.
            query_idx = (
                self.last_layer_query_idx if layer_idx + 1 == len(self.blocks) else None
            )
            x_residual = self._start_residual(layer, 'attention', x)
            x_residual, _ = layer['attention'](
                x_residual if query_idx is None else x_residual[:, query_idx],
                x_residual,
                *self._get_kv_compressions(layer),
            )
            if query_idx is not None:
                # Keep the residual stream aligned with the selected queries.
                x = x[:, query_idx]
            x = self._end_residual(layer, 'attention', x, x_residual)

            x_residual = self._start_residual(layer, 'ffn', x)
            x_residual = layer['ffn'](x_residual)
            x = self._end_residual(layer, 'ffn', x, x_residual)
            x = layer['output'](x)

        x = self.head(x)
        return x