# Multi-Head_attention.py
import tensorflow as tf


def multihead_attention(queries,
                        keys,
                        num_units=None,
                        num_heads=8,
                        dropout_rate=0,
                        is_training=True,
                        causality=False,
                        scope="multihead_attention",
                        reuse=None):
    with tf.variable_scope(scope, reuse=reuse):
        if num_units is None:  # default the attention size C to the last dimension of queries
            num_units = queries.get_shape().as_list()[-1]

        # Linear projections
        Q = tf.layers.dense(queries, num_units, activation=tf.nn.relu)  # (N, T_q, C)
        K = tf.layers.dense(keys, num_units, activation=tf.nn.relu)     # (N, T_k, C)
        V = tf.layers.dense(keys, num_units, activation=tf.nn.relu)     # (N, T_k, C)

        # Split into heads along the channel axis and stack the heads along the batch axis
        Q_ = tf.concat(tf.split(Q, num_heads, axis=-1), axis=0)  # (h*N, T_q, C/h)
        K_ = tf.concat(tf.split(K, num_heads, axis=-1), axis=0)  # (h*N, T_k, C/h)
        V_ = tf.concat(tf.split(V, num_heads, axis=-1), axis=0)  # (h*N, T_k, C/h)

        # Dot-product attention scores
        outputs = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))  # (h*N, T_q, T_k)

        # Scale: outputs = outputs / sqrt(d_k)
        outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)

        # Key masking: block attention to all-zero (padding) key positions
        # see: https://github.com/Kyubyong/transformer/issues/3
        key_masks = tf.sign(tf.abs(tf.reduce_sum(keys, axis=-1)))  # (N, T_k)
        key_masks = tf.tile(key_masks, [num_heads, 1])  # (h*N, T_k)
        key_masks = tf.tile(tf.expand_dims(key_masks, 1), [1, tf.shape(queries)[1], 1])  # (h*N, T_q, T_k)
        paddings = tf.ones_like(outputs) * (-2 ** 32 + 1)  # effectively -infinity before softmax
        outputs = tf.where(tf.equal(key_masks, 0), paddings, outputs)  # (h*N, T_q, T_k)

        # Causality = future blinding: each query position may only attend to keys at or before it
        if causality:
            diag_vals = tf.ones_like(outputs[0, :, :])  # (T_q, T_k)
            # tf.linalg.LinearOperatorLowerTriangular replaces the deprecated tf.contrib.linalg.LinearOperatorTriL
            tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # (T_q, T_k)
            masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(outputs)[0], 1, 1])  # (h*N, T_q, T_k)
            paddings = tf.ones_like(masks) * (-2 ** 32 + 1)
            outputs = tf.where(tf.equal(masks, 0), paddings, outputs)  # (h*N, T_q, T_k)

        # Activation: turn the scores into attention weights
        outputs = tf.nn.softmax(outputs)  # (h*N, T_q, T_k)

        # Query masking: zero out weight rows for all-zero (padding) query positions
        query_masks = tf.sign(tf.abs(tf.reduce_sum(queries, axis=-1)))  # (N, T_q)
        query_masks = tf.tile(query_masks, [num_heads, 1])  # (h*N, T_q)
        query_masks = tf.tile(tf.expand_dims(query_masks, -1), [1, 1, tf.shape(keys)[1]])  # (h*N, T_q, T_k)
        outputs *= query_masks  # (h*N, T_q, T_k)

        # Dropout on the attention weights
        outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=tf.convert_to_tensor(is_training))

        # Weighted sum of the value vectors
        outputs = tf.matmul(outputs, V_)  # (h*N, T_q, C/h)

        # Restore shape: undo the head split
        outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2)  # (N, T_q, C)

        # Residual connection
        outputs += queries

        # Layer normalization (helper defined elsewhere in the repository)
        outputs = layer_normalization(outputs)
    return outputs
# https://github.com/TobiasLee/Text-Classification/blob/master/models/modules/multihead.py
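# ---------------------------------------------------------------------------
# The layer_normalization helper called above is not defined in this file; it
# lives elsewhere in the repository. The function below is a minimal,
# assumption-based stand-in (standard layer normalization over the last axis
# with a learned gain and bias), not necessarily the repository's exact code.
def layer_normalization(inputs, epsilon=1e-8, scope="layer_normalization", reuse=None):
    with tf.variable_scope(scope, reuse=reuse):
        params_shape = inputs.get_shape()[-1:]
        mean, variance = tf.nn.moments(inputs, axes=[-1], keep_dims=True)
        beta = tf.get_variable("beta", params_shape, initializer=tf.zeros_initializer())
        gamma = tf.get_variable("gamma", params_shape, initializer=tf.ones_initializer())
        normalized = (inputs - mean) / ((variance + epsilon) ** 0.5)
        return gamma * normalized + beta


# Hypothetical usage sketch (assumes TF 1.x graph mode and zero-padded inputs;
# the placeholder shapes and hyperparameters below are illustrative only).
if __name__ == "__main__":
    queries = tf.placeholder(tf.float32, [None, 10, 512])  # (N, T_q, C)
    keys = tf.placeholder(tf.float32, [None, 15, 512])      # (N, T_k, C)
    attended = multihead_attention(queries, keys,
                                   num_units=512,
                                   num_heads=8,
                                   dropout_rate=0.1,
                                   is_training=True,
                                   causality=False)         # (N, T_q, C)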