-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsimilarity.py
46 lines (35 loc) · 1.28 KB
/
similarity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import sklearn
def pairwise_distance(a, squared=False):
"""Computes the pairwise distance matrix with numerical stability."""
pairwise_distances_squared = torch.add(
a.pow(2).sum(dim=1, keepdim=True).expand(a.size(0), -1),
torch.t(a).pow(2).sum(dim=0, keepdim=True).expand(a.size(0), -1)
) - 2 * (
torch.mm(a, torch.t(a))
)
# Deal with numerical inaccuracies. Set small negatives to zero.
pairwise_distances_squared = torch.clamp(
pairwise_distances_squared, min=0.0
)
# Get the mask where the zero distances are at.
error_mask = torch.le(pairwise_distances_squared, 0.0)
# Optionally take the sqrt.
if squared:
pairwise_distances = pairwise_distances_squared
else:
pairwise_distances = torch.sqrt(
pairwise_distances_squared + error_mask.float() * 1e-16
)
# Undo conditionally adding 1e-16.
pairwise_distances = torch.mul(
pairwise_distances,
(error_mask == False).float()
)
# Explicitly set diagonals to zero.
mask_offdiagonals = 1 - torch.eye(
*pairwise_distances.size(),
device=pairwise_distances.device
)
pairwise_distances = torch.mul(pairwise_distances, mask_offdiagonals)
return pairwise_distances