-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathutils.py
27 lines (21 loc) · 936 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
def artanh(x):
    """Return the elementwise inverse hyperbolic tangent of tensor ``x``.

    Mathematically equal to ``0.5 * log((1 + x) / (1 - x))``, but evaluated as
    ``0.5 * (log1p(x) - log1p(-x))``. Forming ``1 + x`` first rounds tiny
    inputs to exactly 1; for example, ``x = 1e-10`` in float32 would yield 0
    instead of 1e-10. ``log1p`` keeps full relative precision near zero.

    Args:
        x: tensor with values expected in (-1, 1). ``x = 1`` / ``x = -1`` give
           ``+inf`` / ``-inf``; ``|x| > 1`` gives NaN (same as the log-ratio form).

    Returns:
        Tensor of the same shape as ``x``.
    """
    return 0.5 * (torch.log1p(x) - torch.log1p(-x))
def p_exp_map(v):
    """Map tangent vectors at the origin onto the unit Poincaré ball.

    Computes ``tanh(||v||) * v / ||v||``, with the L2 norm taken over the last
    dimension of ``v``.

    Args:
        v: tensor whose last dimension holds vector coordinates.

    Returns:
        Tensor with the same shape as ``v``.
    """
    # A floor on the norm makes a zero vector map to zero rather than 0/0 = NaN.
    length = torch.norm(v, 2, dim=-1, keepdim=True).clamp(min=1e-10)
    scale = torch.tanh(length)
    return scale * v / length
def p_log_map(v):
    """Map points of the unit Poincaré ball to the tangent space at the origin.

    Computes ``artanh(||v||) * v / ||v||``, with the L2 norm taken over the last
    dimension. For norms inside the clamp range below, this undoes ``p_exp_map``.

    Args:
        v: tensor whose last dimension holds point coordinates.

    Returns:
        Tensor with the same shape as ``v``.
    """
    # The lower bound avoids 0/0 for a zero vector. The upper bound keeps
    # artanh finite for points on or beyond the ball's boundary.
    length = torch.clamp(torch.norm(v, 2, dim=-1, keepdim=True), min=1e-10, max=1 - 1e-5)
    return artanh(length) * v / length
def full_p_exp_map(x, v):
    """Exponential map at base point ``x`` of the unit Poincaré ball.

    Returns ``p_sum(x, tanh(||v|| / (1 - ||x||^2)) * v / ||v||)``, i.e. the
    Mobius sum of ``x`` and a rescaled copy of the tangent vector ``v``. Norms
    are taken over the last dimension.

    Args:
        x: base points on the ball.
        v: tangent vectors at ``x``; must broadcast against ``x``.

    Returns:
        Tensor of points produced by ``p_sum``.
    """
    # Cap ||x||^2 just below 1 so the 1 / (1 - ||x||^2) factor stays finite.
    x_sq = torch.clamp(torch.sum(x * x, dim=-1, keepdim=True), 0, 1 - 1e-5)
    # Floor ||v|| so a zero tangent vector gives a zero step instead of NaN.
    v_len = torch.clamp(torch.norm(v, 2, dim=-1, keepdim=True), min=1e-10)
    step = torch.tanh(v_len / (1 - x_sq)) * v / v_len
    return p_sum(x, step)
def p_sum(x, y):
    """Mobius addition of ``x`` and ``y`` on the unit Poincaré ball.

    Evaluates, reducing over the last dimension::

        ((1 + 2<x,y> + ||y||^2) x + (1 - ||x||^2) y) / (1 + 2<x,y> + ||x||^2 ||y||^2)

    Only the squared-norm terms are clamped (to at most 1 - 1e-5). The inner
    product and the vectors ``x`` and ``y`` are used as given.

    Args:
        x: tensor of points; must broadcast against ``y``.
        y: tensor of points.

    Returns:
        Tensor of the broadcast shape of ``x`` and ``y``.
    """
    def _capped_sq_norm(t):
        # Squared L2 norm over the last dim, capped just below 1.
        return torch.clamp(torch.sum(t * t, dim=-1, keepdim=True), 0, 1 - 1e-5)

    x_sq = _capped_sq_norm(x)
    y_sq = _capped_sq_norm(y)
    two_xy = 2 * torch.sum(x * y, dim=-1, keepdim=True)
    coef_x = 1 + two_xy + y_sq
    coef_y = 1 - x_sq
    denom = 1 + two_xy + x_sq * y_sq
    return (coef_x * x + coef_y * y) / denom