-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
51 lines (39 loc) · 1.69 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import numpy as np
from scipy.optimize import curve_fit
from scipy import stats
import torch
import torch.nn.functional as F
import torchsort
import torch.nn as nn
def logistic_func(X, bayta1, bayta2, bayta3, bayta4):
    """Evaluate a four-parameter logistic curve at ``X``.

    Computes
    ``bayta2 + (bayta1 - bayta2) / (1 + exp(-(X - bayta3) / |bayta4|))``.

    Args:
        X: Point(s) at which to evaluate the curve.
        bayta1: Value the curve approaches as ``X`` grows large.
        bayta2: Value the curve approaches as ``X`` grows very negative.
        bayta3: Location of the midpoint; at ``X == bayta3`` the result is
            the average of ``bayta1`` and ``bayta2``.
        bayta4: Scale of the transition. Only its absolute value is used,
            so its sign has no effect.

    Returns:
        The curve value(s) for ``X``.
    """
    shifted = X - bayta3
    # The absolute value keeps the curve orientation independent of the
    # sign of the scale parameter.
    scale = np.abs(bayta4)
    exponent = np.negative(np.divide(shifted, scale))
    denominator = 1 + np.exp(exponent)
    span = bayta1 - bayta2
    return bayta2 + np.divide(span, denominator)
def fit_function(y_label, y_output):
    """Fit ``logistic_func`` from ``y_output`` to ``y_label`` and apply it.

    The four curve parameters are estimated with
    ``scipy.optimize.curve_fit``, using ``y_output`` as the independent
    variable and ``y_label`` as the target.

    Args:
        y_label: Target values the fitted curve should reproduce.
        y_output: Raw values to be mapped through the fitted curve.

    Returns:
        ``logistic_func`` evaluated at ``y_output`` with the fitted
        parameters.
    """
    # Starting point for the optimizer: the label range for the two
    # plateaus, the mean raw output for the midpoint, and 0.5 for the scale.
    initial_guess = [
        np.max(y_label),
        np.min(y_label),
        np.mean(y_output),
        0.5,
    ]
    fitted_params = curve_fit(
        logistic_func,
        y_output,
        y_label,
        p0=initial_guess,
        maxfev=100000000,
    )[0]
    return logistic_func(y_output, *fitted_params)
def performance_fit(y_label, y_output):
    """Compute correlation and error metrics after a logistic mapping.

    ``y_output`` is first mapped onto the scale of ``y_label`` with
    ``fit_function``. The Pearson correlation and the RMSE are computed on
    the mapped values; the Spearman and Kendall rank correlations are
    computed on the raw ``y_output``.

    Args:
        y_label: Array-like of reference values.
        y_output: Array-like of values to evaluate, same length as
            ``y_label``.

    Returns:
        Tuple ``(PLCC, SRCC, KRCC, RMSE)``.
    """
    # Coerce once so plain sequences behave the same as NumPy arrays in the
    # arithmetic below.
    y_label = np.asarray(y_label)
    y_output = np.asarray(y_output)

    y_output_logistic = fit_function(y_label, y_output)

    plcc = stats.pearsonr(y_output_logistic, y_label)[0]
    srcc = stats.spearmanr(y_output, y_label)[0]
    # Use the public scipy.stats.kendalltau; the scipy.stats.stats
    # sub-namespace is deprecated.
    krcc = stats.kendalltau(y_output, y_label)[0]
    rmse = np.sqrt(((y_output_logistic - y_label) ** 2).mean())
    return plcc, srcc, krcc, rmse
def performance_no_fit(y_label, y_output):
    """Compute correlation and error metrics directly on the raw values.

    No mapping is applied: every metric compares ``y_output`` with
    ``y_label`` as given.

    Args:
        y_label: Array-like of reference values.
        y_output: Array-like of values to evaluate, same length as
            ``y_label``.

    Returns:
        Tuple ``(PLCC, SRCC, KRCC, RMSE)``: Pearson, Spearman and Kendall
        correlation coefficients followed by the root-mean-square error.
    """
    # Coerce once so plain lists/tuples work in the element-wise
    # subtraction used for the RMSE.
    y_label = np.asarray(y_label)
    y_output = np.asarray(y_output)

    plcc = stats.pearsonr(y_output, y_label)[0]
    srcc = stats.spearmanr(y_output, y_label)[0]
    # Use the public scipy.stats.kendalltau; the scipy.stats.stats
    # sub-namespace is deprecated.
    krcc = stats.kendalltau(y_output, y_label)[0]
    rmse = np.sqrt(((y_output - y_label) ** 2).mean())
    return plcc, srcc, krcc, rmse
def plcc_loss(y_pred, y):
    """Return a correlation-based loss between ``y_pred`` and ``y``.

    Both inputs are standardized over all their elements (population
    standard deviation, with ``1e-8`` added to it before dividing). The
    loss is the average of two terms, each an MSE divided by 4:

    * the MSE between the standardized prediction and target;
    * the MSE between the standardized prediction scaled by ``rho`` and the
      standardized target, where ``rho`` is the mean product of the two
      standardized tensors.

    Args:
        y_pred: Tensor of predicted values.
        y: Tensor of target values.

    Returns:
        The scalar loss, cast to ``float32``.
    """
    mse = torch.nn.functional.mse_loss

    def _standardize(values):
        # unbiased=False selects the population standard deviation; the
        # epsilon avoids dividing by an exactly-zero spread.
        spread, center = torch.std_mean(values, unbiased=False)
        return (values - center) / (spread + 1e-8)

    pred_norm = _standardize(y_pred)
    target_norm = _standardize(y)

    direct_term = mse(pred_norm, target_norm) / 4
    rho = torch.mean(pred_norm * target_norm)
    scaled_term = mse(rho * pred_norm, target_norm) / 4
    return ((direct_term + scaled_term) / 2).float()