-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy path4.0 -Luong_attention.py
43 lines (32 loc) · 1.55 KB
/
4.0 -Luong_attention.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class Attn(nn.Module):
    """Luong-style attention over a sequence of encoder outputs.

    One decoder state is scored against every encoder output using the
    scoring function selected by ``method``; the scores are then normalised
    with a softmax.

    Supported ``method`` values:
        'dot':     hidden . encoder_output
        'general': hidden . attn(encoder_output)
        'concat':  other . attn([hidden; encoder_output])

    Args:
        method: Name of the scoring function (see above).
        hidden_size: Number of features in the decoder state and in each
            encoder output.
        max_length: Not used by this module; kept so that existing callers
            that pass it keep working.
    """

    def __init__(self, method, hidden_size, max_length=MAX_LENGTH):
        super(Attn, self).__init__()
        self.method = method
        self.hidden_size = hidden_size

        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            # Allocate and then explicitly initialise the projection vector;
            # an uninitialised tensor holds arbitrary memory (possibly
            # NaN/inf), which makes training non-reproducible.
            self.other = nn.Parameter(torch.empty(1, hidden_size))
            nn.init.xavier_uniform_(self.other)

    def forward(self, hidden, encoder_outputs):
        """Compute attention weights for one decoder state.

        Args:
            hidden: Decoder state tensor holding ``hidden_size`` values
                (presumably 1 x hidden_size -- confirm against the caller).
            encoder_outputs: Sequence of encoder outputs, indexed along its
                first dimension; each item holds ``hidden_size`` values.

        Returns:
            Tensor of shape 1 x 1 x seq_len whose entries sum to 1.
        """
        seq_len = len(encoder_outputs)

        # One energy per encoder step, kept on the same device and dtype as
        # the decoder state instead of relying on a global CUDA flag.
        attn_energies = torch.zeros(
            seq_len, dtype=hidden.dtype, device=hidden.device)

        # Calculate the energy for each encoder output
        for i, encoder_output in enumerate(encoder_outputs):
            attn_energies[i] = self.score(hidden, encoder_output)

        # Normalise energies to weights in the range 0 to 1 (explicit dim:
        # the energies form a 1-D vector), then resize to 1 x 1 x seq_len
        return F.softmax(attn_energies, dim=0).unsqueeze(0).unsqueeze(0)

    def score(self, hidden, encoder_output):
        """Return the scalar attention energy for one encoder output.

        Both inputs are flattened first because ``torch.dot`` only accepts
        1-D tensors. This assumes a single sequence (batch size 1); with a
        larger batch the flattening would mix examples -- TODO confirm
        callers never pass batched inputs.

        Raises:
            ValueError: If ``self.method`` is not 'dot', 'general' or
                'concat'.
        """
        hidden = hidden.reshape(-1)
        encoder_output = encoder_output.reshape(-1)

        if self.method == 'dot':
            return hidden.dot(encoder_output)

        if self.method == 'general':
            return hidden.dot(self.attn(encoder_output))

        if self.method == 'concat':
            # NOTE(review): Luong et al. apply tanh to the projected
            # concatenation before the final dot product; this code does
            # not. Left unchanged so existing trained weights behave the
            # same -- confirm whether the omission is intentional.
            energy = self.attn(torch.cat((hidden, encoder_output), -1))
            return self.other.reshape(-1).dot(energy)

        # Previously an unknown method fell through and returned None,
        # which only failed later with an unrelated-looking error.
        raise ValueError(
            "Unknown attention method: {!r} "
            "(expected 'dot', 'general' or 'concat')".format(self.method))