From ec3c92c6a5466d827c088198fc7233c684a26112 Mon Sep 17 00:00:00 2001 From: beckynevin Date: Fri, 17 May 2024 14:02:10 -0600 Subject: [PATCH] modified to be DER_wst with a scaling by this term --- src/models/models.py | 4 ++-- src/scripts/Aleatoric.py | 2 +- src/utils/defaults.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/models/models.py b/src/models/models.py index 50619c3..471856d 100644 --- a/src/models/models.py +++ b/src/models/models.py @@ -189,7 +189,7 @@ def loss_der(y, y_pred, coeff): gamma, nu, alpha, beta = y[:, 0], y[:, 1], y[:, 2], y[:, 3] error = gamma - y_pred omega = 2.0 * beta * (1.0 + nu) - + w_st = torch.sqrt(beta * (1 + nu) / (alpha * nu)) # define aleatoric and epistemic uncert u_al = np.sqrt( beta.detach().numpy() @@ -204,7 +204,7 @@ def loss_der(y, y_pred, coeff): + (alpha + 0.5) * torch.log(error**2 * nu + omega) + torch.lgamma(alpha) - torch.lgamma(alpha + 0.5) - + coeff * torch.abs(error) * (2.0 * nu + alpha) + + (coeff * torch.abs(error / w_st) * (2.0 * nu + alpha)) ), u_al, u_ep, diff --git a/src/scripts/Aleatoric.py b/src/scripts/Aleatoric.py index 74e0b63..c1cdfd2 100644 --- a/src/scripts/Aleatoric.py +++ b/src/scripts/Aleatoric.py @@ -290,7 +290,7 @@ def beta_type(value): ax.set_title("Deep Evidential Regression") elif model[0:2] == "DE": ax.set_title("Deep Ensemble (100 models)") - ax.set_ylim([0, 14]) + ax.set_ylim([0, 11]) plt.legend() if config.get_item("analysis", "savefig", "Analysis"): plt.savefig( diff --git a/src/utils/defaults.py b/src/utils/defaults.py index d1ae775..21d6de4 100644 --- a/src/utils/defaults.py +++ b/src/utils/defaults.py @@ -100,7 +100,7 @@ }, "analysis": { "noise_level_list": ["low", "medium", "high"], - "model_names_list": ["DER", "DE_desiderata_2"], + "model_names_list": ["DER_wst", "DE_desiderata_2"], # ["DER_desiderata_2", "DE_desiderata_2"] "plot": True, "savefig": False,