-
Notifications
You must be signed in to change notification settings - Fork 983
/
Copy pathSimplE.py
executable file
·55 lines (46 loc) · 1.92 KB
/
SimplE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import torch.nn as nn
from .Model import Model
class SimplE(Model):
def __init__(self, ent_tot, rel_tot, dim = 100):
super(SimplE, self).__init__(ent_tot, rel_tot)
self.dim = dim
self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim)
self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim)
self.rel_inv_embeddings = nn.Embedding(self.rel_tot, self.dim)
nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
nn.init.xavier_uniform_(self.rel_inv_embeddings.weight.data)
def _calc_avg(self, h, t, r, r_inv):
return (torch.sum(h * r * t, -1) + torch.sum(h * r_inv * t, -1))/2
def _calc_ingr(self, h, r, t):
return torch.sum(h * r * t, -1)
def forward(self, data):
batch_h = data['batch_h']
batch_t = data['batch_t']
batch_r = data['batch_r']
h = self.ent_embeddings(batch_h)
t = self.ent_embeddings(batch_t)
r = self.rel_embeddings(batch_r)
r_inv = self.rel_inv_embeddings(batch_r)
score = self._calc_avg(h, t, r, r_inv)
return score
def regularization(self, data):
batch_h = data['batch_h']
batch_t = data['batch_t']
batch_r = data['batch_r']
h = self.ent_embeddings(batch_h)
t = self.ent_embeddings(batch_t)
r = self.rel_embeddings(batch_r)
r_inv = self.rel_inv_embeddings(batch_r)
regul = (torch.mean(h ** 2) + torch.mean(t ** 2) + torch.mean(r ** 2) + torch.mean(r_inv ** 2)) / 4
return regul
def predict(self, data):
batch_h = data['batch_h']
batch_t = data['batch_t']
batch_r = data['batch_r']
h = self.ent_embeddings(batch_h)
t = self.ent_embeddings(batch_t)
r = self.rel_embeddings(batch_r)
score = -self._calc_ingr(h, r, t)
return score.cpu().data.numpy()