Skip to content

Commit

Permalink
Merge pull request #97 from deepskies/issue/config_multiple_initializ…
Browse files Browse the repository at this point in the history
…ations

Issue/config multiple initializations
  • Loading branch information
beckynevin authored May 16, 2024
2 parents d4ed220 + 862d079 commit 52a8436
Show file tree
Hide file tree
Showing 7 changed files with 465 additions and 56 deletions.
19 changes: 14 additions & 5 deletions src/analyze/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def load_checkpoint(
nmodel=None,
COEFF=0.5,
loss="SDER",
load_rs_chk=False,
rs=42
):
"""
Load PyTorch model checkpoint from a .pt file.
Expand All @@ -31,11 +33,18 @@ def load_checkpoint(
:return: Loaded model
"""
if model_name[0:3] == "DER":
file_name = (
str(path)
+ f"{model_name}_noise_{noise}_loss_{loss}"
+ f"_COEFF_{COEFF}_epoch_{epoch}.pt"
)
if load_rs_chk:
file_name = (
str(path)
+ f"{model_name}_noise_{noise}_loss_{loss}"
+ f"_COEFF_{COEFF}_epoch_{epoch}_rs_{rs}.pt"
)
else:
file_name = (
str(path)
+ f"{model_name}_noise_{noise}_loss_{loss}"
+ f"_COEFF_{COEFF}_epoch_{epoch}.pt"
)
checkpoint = torch.load(file_name, map_location=device)
elif model_name[0:2] == "DE":
file_name = (
Expand Down
306 changes: 306 additions & 0 deletions src/scripts/AleatoricInits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
import os
import yaml
import argparse
import numpy as np
import torch
import matplotlib.pyplot as plt
from utils.config import Config
from utils.defaults import DefaultsAnalysis
from data.data import DataPreparation
from analyze.analyze import AggregateCheckpoints


def parse_args():
parser = argparse.ArgumentParser(description="data handling module")
# there are three options with the parser:
# 1) Read from a yaml
# 2) Reads from the command line and default file
# and dumps to yaml

# option to pass name of config
parser.add_argument("--config", "-c", default=None)

# model
# we need some info about the model to run this analysis
# path to save the model results
parser.add_argument("--dir",
default=DefaultsAnalysis["common"]["dir"])
# now args for model
parser.add_argument(
"--n_models",
type=int,
default=DefaultsAnalysis["model"]["n_models"],
help="Number of MVEs in the ensemble",
)
parser.add_argument(
"--BETA",
type=beta_type,
required=False,
default=DefaultsAnalysis["model"]["BETA"],
help="If loss_type is bnn_loss, specify a beta as a float or \
there are string options: linear_decrease, \
step_decrease_to_0.5, and step_decrease_to_1.0",
)
parser.add_argument(
"--COEFF",
type=float,
required=False,
default=DefaultsAnalysis["model"]["COEFF"],
help="COEFF for DER",
)
parser.add_argument(
"--loss_type",
type=str,
required=False,
default=DefaultsAnalysis["model"]["loss_type"],
help="loss_type for DER, either SDER or DER",
)
parser.add_argument(
"--noise_level_list",
type=list,
required=False,
default=DefaultsAnalysis["analysis"]["noise_level_list"],
help="Noise levels to compare",
)
parser.add_argument(
"--model_names_list",
type=list,
required=False,
default=DefaultsAnalysis["analysis"]["model_names_list"],
help="Beginning of name for saved checkpoints and figures",
)
parser.add_argument(
"--n_epochs",
type=int,
required=False,
default=DefaultsAnalysis["model"]["n_epochs"],
help="number of epochs",
)
parser.add_argument(
"--plot",
action="store_true",
default=DefaultsAnalysis["analysis"]["plot"],
help="option to plot in notebook",
)
parser.add_argument(
"--color_list",
type=list,
default=DefaultsAnalysis["plots"]["color_list"],
help="list of named or hexcode colors to use for the noise levels",
)
parser.add_argument(
"--savefig",
action="store_true",
default=DefaultsAnalysis["analysis"]["savefig"],
help="option to save a figure of the true and predicted values",
)
parser.add_argument(
"--verbose",
action="store_true",
default=DefaultsAnalysis["analysis"]["verbose"],
help="verbose option for train",
)
args = parser.parse_args()
args = parser.parse_args()
if args.config is not None:
print("Reading settings from config file", args.config)
config = Config(args.config)

else:
temp_config = DefaultsAnalysis["common"]["temp_config"]
print(
"Reading settings from cli and default, \
dumping to temp config: ",
temp_config,
)
os.makedirs(os.path.dirname(temp_config), exist_ok=True)

# check if args were specified in cli
input_yaml = {
"common": {"dir": args.dir},
"model": {"n_models": args.n_models,
"n_epochs": args.n_epochs,
"BETA": args.BETA,
"COEFF": args.COEFF,
"loss_type": args.loss_type},
"analysis": {
"noise_level_list": args.noise_level_list,
"model_names_list": args.model_names_list,
"plot": args.plot,
"savefig": args.savefig,
"verbose": args.verbose,
},
"plots": {"color_list": args.color_list},
# "metrics": {key: {} for key in args.metrics},
}

yaml.dump(input_yaml, open(temp_config, "w"))
config = Config(temp_config)

return config
# return parser.parse_args()


def beta_type(value):
if isinstance(value, float):
return value
elif value.lower() == "linear_decrease":
return value
elif value.lower() == "step_decrease_to_0.5":
return value
elif value.lower() == "step_decrease_to_1.0":
return value
else:
raise argparse.ArgumentTypeError(
"BETA must be a float or one of 'linear_decrease', \
'step_decrease_to_0.5', 'step_decrease_to_1.0'"
)


if __name__ == "__main__":
config = parse_args()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
noise_list = config.get_item("analysis", "noise_level_list", "Analysis")
color_list = config.get_item("plots", "color_list", "Analysis")
BETA = config.get_item("model", "BETA", "Analysis")
COEFF = config.get_item("model", "COEFF", "Analysis")
loss_type = config.get_item("model", "loss_type", "Analysis")
sigma_list = []
for noise in noise_list:
sigma_list.append(DataPreparation.get_sigma(noise))
root_dir = config.get_item("common", "dir", "Analysis")
path_to_chk = root_dir + "checkpoints/"
path_to_out = root_dir + "analysis/"
# this needs to be redone
rs_list = [1, 2, 3, 4, 5]
# check that this exists and if not make it
if not os.path.isdir(path_to_out):
print('does not exist, making dir', path_to_out)
os.mkdir(path_to_out)
else:
print('already exists', path_to_out)
model_name_list = config.get_item("analysis",
"model_names_list",
"Analysis")
print("model list", model_name_list)
print("noise list", noise_list)
chk_module = AggregateCheckpoints()
# make an empty nested dictionary with keys for
# model names followed by noise levels
al_dict = {
model_name: {noise: {rs: [] for rs in rs_list}
for noise in noise_list}
for model_name in model_name_list
}
al_std_dict = {
model_name: {noise: {rs: [] for rs in rs_list}
for noise in noise_list}
for model_name in model_name_list
}
n_epochs = config.get_item("model", "n_epochs", "Analysis")
for model in model_name_list:
for noise in noise_list:
for rs in rs_list:

# append a noise key
# now run the analysis on the resulting checkpoints
if model[0:3] == "DER":
for epoch in range(n_epochs):
chk = chk_module.load_checkpoint(
model,
noise,
epoch,
DEVICE,
path=path_to_chk,
COEFF=COEFF,
loss=loss_type,
load_rs_chk=True,
rs=rs
)
# path=path_to_chk)
# things to grab: 'valid_mse' and 'valid_bnll'
epistemic_m, aleatoric_m, e_std, a_std = (
chk_module.ep_al_checkpoint_DER(chk)
)
al_dict[model][noise][rs].append(aleatoric_m)
al_std_dict[model][noise][rs].append(a_std)

if model[0:3] == "DE_":
n_models = config.get_item("model", "n_models", "DE")
for epoch in range(n_epochs):
list_mus = []
list_sigs = []
for nmodels in range(n_models):
chk = chk_module.load_checkpoint(
model,
noise,
epoch,
DEVICE,
path=path_to_chk,
BETA=BETA,
nmodel=nmodels,
)
mu_vals, sig_vals = chk_module.ep_al_checkpoint_DE(chk)
list_mus.append(mu_vals)
list_sigs.append(sig_vals)
try:
al_dict[model][noise][nmodels + 1].append(
np.mean(list_sigs))
except KeyError:
continue
# make a two-paneled plot for the different noise levels
# make one panel per model
# for the noise levels:
plt.clf()
fig = plt.figure(figsize=(10, 4))
# try this instead with a fill_between method
for i, model in enumerate(model_name_list):
ax = fig.add_subplot(1, len(model_name_list), i + 1)
# Your plotting code for each model here
ax.set_title(model) # Set title for each subplot
for n, noise in enumerate(noise_list):
for r, rs in enumerate(rs_list):
al = np.array(np.sqrt(al_dict[model][noise][rs]))
'''
al_std = np.array(np.sqrt(al_std_dict[model][noise][rs]))
ax.fill_between(
range(n_epochs),
al - al_std,
al + al_std,
color=color_list[n],
alpha=0.0,
edgecolor=None
)
'''
if r == 0:
ax.plot(
range(n_epochs),
al,
color=color_list[n],
label=r"$\sigma = $" + str(sigma_list[n]),
)
else:
ax.plot(
range(n_epochs),
al,
color=color_list[n])
ax.axhline(y=sigma_list[n], color=color_list[n], ls='--')
ax.set_ylabel("Aleatoric Uncertainty")
ax.set_xlabel("Epoch")
if model[0:3] == "DER":
ax.set_title("Deep Evidential Regression")
elif model[0:2] == "DE":
ax.set_title("Deep Ensemble (100 models)")
ax.set_ylim([0, 14])
plt.legend()
if config.get_item("analysis", "savefig", "Analysis"):
plt.savefig(
str(path_to_out)
+ "aleatoric_uncertainty_n_epochs_"
+ str(n_epochs)
+ "_n_models_DE_"
+ str(n_models)
+ ".png"
)
if config.get_item("analysis", "plot", "Analysis"):
plt.show()
18 changes: 18 additions & 0 deletions src/scripts/DeepEvidentialRegression.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,18 @@ def parse_args():
default=DefaultsDER["model"]["savefig"],
help="option to save a figure of the true and predicted values",
)
parser.add_argument(
"--save_chk_random_seed_init",
action="store_true",
default=DefaultsDER["model"]["save_chk_random_seed_init"],
help="option to save the chk with a random seed",
)
parser.add_argument(
"--rs",
type=int,
default=DefaultsDER["model"]["rs"],
help="define a random seed to save",
)
parser.add_argument(
"--verbose",
action="store_true",
Expand Down Expand Up @@ -201,6 +213,8 @@ def parse_args():
"overwrite_final_checkpoint": args.overwrite_final_checkpoint,
"plot": args.plot,
"savefig": args.savefig,
"save_chk_random_seed_init": args.save_chk_random_seed_init,
"rs": args.rs,
"verbose": args.verbose,
},
"data": {
Expand Down Expand Up @@ -307,5 +321,9 @@ def parse_args():
),
plot=config.get_item("model", "plot", "DER"),
savefig=config.get_item("model", "savefig", "DER"),
set_and_save_rs=config.get_item("model",
"save_chk_random_seed_init",
"DER"),
rs=config.get_item("model", "rs", "DER"),
verbose=config.get_item("model", "verbose", "DER"),
)
Loading

0 comments on commit 52a8436

Please sign in to comment.