-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathmodule.py
307 lines (254 loc) · 13.2 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
import math
import copy
import torch
import torch.nn as nn
import torch.nn.functional as func
from typing import Tuple, Optional
from torch import Tensor
class TransformerEncoderLayer(nn.Module):
    r"""Post-norm Transformer encoder layer that also returns its attention weights.

    Self-attention followed by a position-wise feed-forward network, each wrapped
    in dropout + residual add + LayerNorm, following the standard encoder layer of
    "Attention Is All You Need" (Ashish Vaswani, Noam Shazeer, Niki Parmar,
    Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia
    Polosukhin. 2017. In Advances in Neural Information Processing Systems,
    pages 6000-6010). Unlike ``torch.nn.TransformerEncoderLayer``, ``forward``
    returns an ``(output, attention_weights)`` tuple.

    Args:
        d_model: number of expected features in the input (required).
        nhead: number of heads in the multi-head attention (required).
        dim_feedforward: hidden size of the feed-forward network (default=2048).
        dropout: dropout probability (default=0.1).
        activation: intermediate activation name, "relu" or "gelu" (default="relu").

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out, attn = encoder_layer(src)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super().__init__()
        # Submodules are created in this exact order: reordering would change the
        # parameter initialisation under a fixed seed and the parameter ordering.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # feed-forward sub-layer: d_model -> dim_feedforward -> d_model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # layer norms applied after each residual add
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # dropout on each sub-layer output before the residual add
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # A restored state without an 'activation' entry falls back to ReLU.
        state.setdefault('activation', func.relu)
        super().__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        r"""Run one encoder layer over ``src``.

        Args:
            src: input sequence to the layer (required).
            src_mask: attention mask for the sequence, passed to nn.MultiheadAttention as ``attn_mask`` (optional).
            src_key_padding_mask: per-batch key padding mask, passed as ``key_padding_mask`` (optional).

        Returns:
            ``(output, attn)``: the layer output (same shape as ``src``) and the
            attention weights exactly as returned by ``nn.MultiheadAttention``.
        """
        attn_out, attn_weights = self.self_attn(
            src, src, src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        # residual + norm around the self-attention sub-layer
        src = self.norm1(src + self.dropout1(attn_out))
        # residual + norm around the feed-forward sub-layer
        ff_hidden = self.activation(self.linear1(src))
        ff_out = self.linear2(self.dropout(ff_hidden))
        src = self.norm2(src + self.dropout2(ff_out))
        return src, attn_weights
class TransformerEncoder(nn.Module):
    r"""Stack of ``num_layers`` deep-copied encoder layers that also collects each layer's attention weights.

    Args:
        encoder_layer: a TransformerEncoderLayer instance to clone (required).
        num_layers: number of stacked copies (required).
        norm: optional module applied to the final output.

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out, attns = transformer_encoder(src)
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        r"""Feed ``src`` through every layer in order.

        Args:
            src: input sequence to the encoder (required).
            mask: attention mask, forwarded to each layer as ``src_mask`` (optional).
            src_key_padding_mask: per-batch key padding mask, forwarded to each layer (optional).

        Returns:
            ``(output, attns)``, where ``attns`` stacks the per-layer attention
            weights along a new leading dimension of size ``num_layers``.
            With zero layers, ``torch.stack`` of the empty list raises.
        """
        per_layer_attn = []
        output = src
        for layer in self.layers:
            output, layer_attn = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
            per_layer_attn.append(layer_attn)
        stacked_attn = torch.stack(per_layer_attn)
        if self.norm is not None:
            output = self.norm(output)
        return output, stacked_attn
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def _get_activation_fn(activation):
if activation == "relu":
return func.relu
elif activation == "gelu":
return func.gelu
raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
class PositionalEncoding(nn.Module):
    r"""Add fixed sinusoidal position information to a sequence of embeddings.

    The encodings have the same dimension as the embeddings, so the two can be
    summed. Even embedding indices use sine and odd indices use cosine, at
    geometrically spaced frequencies:

    .. math::
        \text{PE}(pos, 2i) = \sin\left(pos / 10000^{2i / d_{model}}\right)

        \text{PE}(pos, 2i+1) = \cos\left(pos / 10000^{2i / d_{model}}\right)

    where :math:`pos` is the token position and :math:`i` indexes pairs of
    embedding dimensions.

    Args:
        d_model: the embedding dimension (required). Odd values are supported;
            the last (even-indexed) column then carries only a sine term.
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).

    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)  # d_model: word embedding size
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len,) -> (max_len, 1)
        # 1 / 10000^(2i / d_model) for 2i = 0, 2, 4, ..., computed in log space as
        # exp(2i * (-ln(10000) / d_model)); shape (ceil(d_model / 2),)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even indices, (max_len, ceil(d_model/2))
        # There are only d_model // 2 odd columns: for odd d_model the last
        # frequency has no cosine partner, so it is dropped here (previously a shape error).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])  # odd indices
        pe = pe.unsqueeze(0).transpose(0, 1)  # (max_len, d_model) -> (1, max_len, d_model) -> (max_len, 1, d_model)
        self.register_buffer('pe', pe)  # not a Parameter, so not updated by the optimizer; accessible as self.pe

    def forward(self, x):
        r"""Add the positional encodings to ``x`` and apply dropout.

        Args:
            x: the sequence fed to the positional encoder model (required).

        Shape:
            x: [sequence length, batch size, embed dim], with sequence length <= max_len
            output: [sequence length, batch size, embed dim]

        Examples:
            >>> output = pos_encoder(x)
        """
        # self.pe is (max_len, 1, d_model); the singleton dim broadcasts over the batch
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
class MLP(nn.Module):
    """Two-layer perceptron mapping a hidden vector to one scalar score per example.

    Architecture: Linear(emsize, emsize) -> Sigmoid -> Linear(emsize, 1).

    Args:
        emsize: size of the input (and hidden) features (default=512).
    """

    def __init__(self, emsize=512):
        super(MLP, self).__init__()
        self.linear1 = nn.Linear(emsize, emsize)
        self.linear2 = nn.Linear(emsize, 1)
        self.sigmoid = nn.Sigmoid()
        self.init_weights()

    def init_weights(self):
        """Initialise weights uniformly in [-0.1, 0.1] and zero both biases."""
        initrange = 0.1
        self.linear1.weight.data.uniform_(-initrange, initrange)
        self.linear2.weight.data.uniform_(-initrange, initrange)
        self.linear1.bias.data.zero_()
        self.linear2.bias.data.zero_()

    def forward(self, hidden):  # (batch_size, emsize)
        """Return the score for each row of ``hidden``, shape (batch_size,)."""
        mlp_vector = self.sigmoid(self.linear1(hidden))  # (batch_size, emsize)
        # Squeeze only the trailing size-1 output dim. A bare squeeze() would also
        # drop the batch dim when batch_size == 1 and return a 0-d tensor.
        rating = self.linear2(mlp_vector).squeeze(-1)  # (batch_size,)
        return rating
def generate_square_subsequent_mask(total_len):
    """Return a (total_len, total_len) boolean causal mask.

    Entry [i, j] is True exactly when j > i, i.e. strictly above the diagonal.
    Under nn.MultiheadAttention's boolean-mask convention, True marks a pair
    that may not attend, so each position sees only itself and earlier ones.
    """
    strictly_upper = torch.ones(total_len, total_len).triu(diagonal=1)  # 1. above the diagonal, 0. elsewhere
    return strictly_upper.bool()
def generate_peter_mask(src_len, tgt_len):
    """Return the PETER attention mask over ``src_len + tgt_len`` positions.

    Starts from generate_square_subsequent_mask (True = attention blocked) and
    additionally unblocks entry [0, 1], so position 0 may attend to position 1.
    Requires src_len + tgt_len >= 2.
    """
    peter = generate_square_subsequent_mask(src_len + tgt_len)
    # allow the first (user) position to attend to the second (item) position
    peter[0, 1] = False
    return peter
class PETER(nn.Module):
    """PETER model: a Transformer encoder over the sequence [user, item, word tokens].

    From the encoder's hidden states it can produce:
    - a rating, from the user position;
    - a token distribution ("context"), from the item position;
    - next-word log-probabilities.

    Args:
        peter_mask: if True, use generate_peter_mask (causal mask that also lets
            position 0 attend to position 1); otherwise a plain causal mask.
        src_len: number of leading positions before the target words; hidden
            states from index ``src_len`` onward are used for sequence prediction.
        tgt_len: number of target positions; only used to size the attention mask.
        pad_idx: token id treated as padding (masked out as attention keys).
        nuser: number of users (size of the user embedding table).
        nitem: number of items (size of the item embedding table).
        ntoken: vocabulary size.
        emsize: embedding / model dimension.
        nhead: number of attention heads.
        nhid: feed-forward hidden size inside each encoder layer.
        nlayers: number of encoder layers.
        dropout: dropout probability (default=0.5).
    """

    def __init__(self, peter_mask, src_len, tgt_len, pad_idx, nuser, nitem, ntoken, emsize, nhead, nhid, nlayers, dropout=0.5):
        super(PETER, self).__init__()
        self.pos_encoder = PositionalEncoding(emsize, dropout)  # emsize: word embedding size
        encoder_layers = TransformerEncoderLayer(emsize, nhead, nhid, dropout)  # nhid: dim_feedforward; one layer = multi-head attention + FFN
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)  # nlayers deep copies of the layer above
        self.user_embeddings = nn.Embedding(nuser, emsize)
        self.item_embeddings = nn.Embedding(nitem, emsize)
        self.word_embeddings = nn.Embedding(ntoken, emsize)
        self.hidden2token = nn.Linear(emsize, ntoken)
        self.recommender = MLP(emsize)
        self.ui_len = 2  # user + item occupy the first two positions
        self.src_len = src_len
        self.pad_idx = pad_idx
        self.emsize = emsize
        # Plain tensor attribute (not a buffer): moved to the input's device in forward().
        if peter_mask:
            self.attn_mask = generate_peter_mask(src_len, tgt_len)
        else:
            self.attn_mask = generate_square_subsequent_mask(src_len + tgt_len)
        self.init_weights()

    def init_weights(self):
        """Uniformly initialise embeddings and the output projection in [-0.1, 0.1]; zero its bias."""
        initrange = 0.1
        self.user_embeddings.weight.data.uniform_(-initrange, initrange)
        self.item_embeddings.weight.data.uniform_(-initrange, initrange)
        self.word_embeddings.weight.data.uniform_(-initrange, initrange)
        self.hidden2token.weight.data.uniform_(-initrange, initrange)
        self.hidden2token.bias.data.zero_()

    def predict_context(self, hidden):
        """Log-probabilities over the vocabulary from the item position (index 1)."""
        context_prob = self.hidden2token(hidden[1])  # (batch_size, ntoken)
        log_context_dis = func.log_softmax(context_prob, dim=-1)
        return log_context_dis

    def predict_rating(self, hidden):
        """Rating from the user position (index 0) via the MLP recommender."""
        rating = self.recommender(hidden[0])  # (batch_size,)
        return rating

    def predict_seq(self, hidden):
        """Log-probabilities over the vocabulary for every position from ``src_len`` on."""
        word_prob = self.hidden2token(hidden[self.src_len:])  # (tgt_len, batch_size, ntoken)
        log_word_prob = func.log_softmax(word_prob, dim=-1)
        return log_word_prob

    def generate_token(self, hidden):
        """Log-probabilities over the vocabulary for the last position only."""
        word_prob = self.hidden2token(hidden[-1])  # (batch_size, ntoken)
        log_word_prob = func.log_softmax(word_prob, dim=-1)
        return log_word_prob

    def forward(self, user, item, text, seq_prediction=True, context_prediction=True, rating_prediction=True):
        '''
        :param user: (batch_size,), torch.int64
        :param item: (batch_size,), torch.int64
        :param text: (total_len - ui_len, batch_size), torch.int64
        :param seq_prediction: bool
        :param context_prediction: bool
        :param rating_prediction: bool
        :return log_word_prob: target tokens (tgt_len, batch_size, ntoken) if seq_prediction=True; the last token (batch_size, ntoken) otherwise.
        :return log_context_dis: (batch_size, ntoken) if context_prediction=True; None otherwise.
        :return rating: (batch_size,) if rating_prediction=True; None otherwise.
        :return attns: (nlayers, batch_size, total_len, total_len)
        '''
        device = user.device
        batch_size = user.size(0)
        total_len = self.ui_len + text.size(0)  # deal with generation when total_len != src_len + tgt_len
        # see nn.MultiheadAttention for attn_mask and key_padding_mask (bool masks: True = masked out)
        attn_mask = self.attn_mask[:total_len, :total_len].to(device)  # (total_len, total_len)
        # The user/item slots are never padding. Build the bool tensor directly on the
        # target device instead of allocating floats on CPU, converting and copying.
        left = torch.zeros(batch_size, self.ui_len, dtype=torch.bool, device=device)  # (batch_size, ui_len)
        right = text.t() == self.pad_idx  # True at padding tokens, (batch_size, total_len - ui_len)
        key_padding_mask = torch.cat([left, right], 1)  # (batch_size, total_len)

        u_src = self.user_embeddings(user.unsqueeze(0))  # (1, batch_size, emsize)
        i_src = self.item_embeddings(item.unsqueeze(0))  # (1, batch_size, emsize)
        w_src = self.word_embeddings(text)  # (total_len - ui_len, batch_size, emsize)
        src = torch.cat([u_src, i_src, w_src], 0)  # (total_len, batch_size, emsize)
        src = src * math.sqrt(self.emsize)
        src = self.pos_encoder(src)
        hidden, attns = self.transformer_encoder(src, attn_mask, key_padding_mask)  # (total_len, batch_size, emsize) vs. (nlayers, batch_size, total_len_tgt, total_len_src)

        rating = self.predict_rating(hidden) if rating_prediction else None  # (batch_size,)
        log_context_dis = self.predict_context(hidden) if context_prediction else None  # (batch_size, ntoken)
        if seq_prediction:
            log_word_prob = self.predict_seq(hidden)  # (tgt_len, batch_size, ntoken)
        else:
            log_word_prob = self.generate_token(hidden)  # (batch_size, ntoken)
        return log_word_prob, log_context_dis, rating, attns