Skip to content

Commit

Permalink
flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
beckynevin committed May 21, 2024
1 parent 5b9c12e commit d701989
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 80 deletions.
18 changes: 8 additions & 10 deletions src/analyze/analyze.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Contains modules to analyze the output checkpoints
# from a trained model and make plots for the paper
import numpy as np
import torch


Expand Down Expand Up @@ -36,20 +35,19 @@ def load_checkpoint(
"""
if model_name[0:3] == "DER":
file_name = (
str(path)
+ f"{model_name}_noise_{noise}_loss_{loss}"
+ f"_COEFF_{COEFF}_epoch_{epoch}"
)
str(path)
+ f"{model_name}_noise_{noise}_loss_{loss}"
+ f"_COEFF_{COEFF}_epoch_{epoch}"
)
if load_rs_chk:
file_name += (f"_rs_{rs}")
file_name += f"_rs_{rs}"
if load_nh_chk:
file_name += (f"_n_hidden_{nh}")
file_name += f"_n_hidden_{nh}"
file_name += ".pt"
elif model_name[0:2] == "DE":
file_name = (
str(path)
+ f"{model_name}_noise_{noise}_beta_{BETA}_"
f"nmodel_{nmodel}_epoch_{epoch}.pt"
str(path) + f"{model_name}_noise_{noise}_beta_{BETA}_"
f"nmodel_{nmodel}_epoch_{epoch}.pt"
)
checkpoint = torch.load(file_name, map_location=device)
return checkpoint
Expand Down
15 changes: 5 additions & 10 deletions src/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ def forward(self, x):
return torch.stack((gamma, nu, alpha, beta), dim=1)


def model_setup_DER(loss_type,
DEVICE,
n_hidden):
def model_setup_DER(loss_type, DEVICE, n_hidden):
# initialize the model from scratch
if loss_type == "SDER":
Layer = SDERLayer
Expand All @@ -82,8 +80,7 @@ def model_setup_DER(loss_type,

# from https://github.com/pasteurlabs/unreasonable_effective_der
# /blob/main/x3_indepth.ipynb
model = torch.nn.Sequential(Model(4, n_hidden),
Layer())
model = torch.nn.Sequential(Model(4, n_hidden), Layer())
model = model.to(DEVICE)
return model, lossFn

Expand Down Expand Up @@ -120,7 +117,6 @@ def model_setup_DE(loss_type, DEVICE):
return model, lossFn



# This following is from PasteurLabs -
# https://github.com/pasteurlabs/unreasonable_effective_der/blob/main/models.py

Expand Down Expand Up @@ -174,14 +170,13 @@ def loss_sder(y, y_pred, coeff):

# define aleatoric and epistemic uncert
u_al = np.sqrt(
(beta.detach().numpy()
* (1 + nu.detach().numpy()))
(beta.detach().numpy() * (1 + nu.detach().numpy()))
/ (alpha.detach().numpy() * nu.detach().numpy())
)
u_ep = 1 / np.sqrt(nu.detach().numpy())

return torch.mean(torch.log(var) + (1.0 + coeff * nu) * error**2 / var), \
u_al, u_ep
return torch.mean(torch.log(var) +
(1.0 + coeff * nu) * error**2 / var), u_al, u_ep


# from martius lab
Expand Down
62 changes: 29 additions & 33 deletions src/scripts/AleatoricInits.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ def parse_args():
# model
# we need some info about the model to run this analysis
# path to save the model results
parser.add_argument("--dir",
default=DefaultsAnalysis["common"]["dir"])
parser.add_argument("--dir", default=DefaultsAnalysis["common"]["dir"])
# now args for model
parser.add_argument(
"--n_models",
Expand Down Expand Up @@ -118,11 +117,13 @@ def parse_args():
# check if args were specified in cli
input_yaml = {
"common": {"dir": args.dir},
"model": {"n_models": args.n_models,
"n_epochs": args.n_epochs,
"BETA": args.BETA,
"COEFF": args.COEFF,
"loss_type": args.loss_type},
"model": {
"n_models": args.n_models,
"n_epochs": args.n_epochs,
"BETA": args.BETA,
"COEFF": args.COEFF,
"loss_type": args.loss_type,
},
"analysis": {
"noise_level_list": args.noise_level_list,
"model_names_list": args.model_names_list,
Expand Down Expand Up @@ -175,10 +176,10 @@ def beta_type(value):
rs_list = [1, 2, 3, 4, 5]
# check that this exists and if not make it
if not os.path.isdir(path_to_out):
print('does not exist, making dir', path_to_out)
print("does not exist, making dir", path_to_out)
os.mkdir(path_to_out)
else:
print('already exists', path_to_out)
print("already exists", path_to_out)
model_name_list = config.get_item("analysis",
"model_names_list",
"Analysis")
Expand All @@ -188,13 +189,11 @@ def beta_type(value):
# make an empty nested dictionary with keys for
# model names followed by noise levels
al_dict = {
model_name: {noise: {rs: [] for rs in rs_list}
for noise in noise_list}
model_name: {noise: {rs: [] for rs in rs_list} for noise in noise_list}
for model_name in model_name_list
}
al_std_dict = {
model_name: {noise: {rs: [] for rs in rs_list}
for noise in noise_list}
model_name: {noise: {rs: [] for rs in rs_list} for noise in noise_list}
for model_name in model_name_list
}
n_epochs = config.get_item("model", "n_epochs", "Analysis")
Expand All @@ -216,13 +215,12 @@ def beta_type(value):
loss=loss_type,
load_rs_chk=True,
rs=rs,
load_nh_chk=False
load_nh_chk=False,
)
# path=path_to_chk)
# things to grab: 'valid_mse' and 'valid_bnll'
_, aleatoric_m, _, a_std = (
_, aleatoric_m, _, a_std = \
chk_module.ep_al_checkpoint_DER(chk)
)
al_dict[model][noise][rs].append(aleatoric_m)
al_std_dict[model][noise][rs].append(a_std)

Expand All @@ -246,7 +244,8 @@ def beta_type(value):
list_vars.append(var_vals)
try:
al_dict[model][noise][nmodels + 1].append(
np.mean(list_vars))
np.mean(list_vars)
)
except KeyError:
continue
# make a two-paneled plot for the different noise levels
Expand All @@ -265,11 +264,11 @@ def beta_type(value):
al = np.array(np.sqrt(al_dict[model][noise][rs]))
else:
al = np.array(al_dict[model][noise][rs])
'''
# it doesn't really make sense to plot the std for the
"""
# it doesn't really make sense to plot the std for the
# case of the DE because each individual model
# makes up one in the ensemble
'''
"""
if model[0:3] == "DER":
al_std = np.array(al_std_dict[model][noise][rs])
ax.fill_between(
Expand All @@ -278,7 +277,7 @@ def beta_type(value):
al + al_std,
color=color_list[n],
alpha=0.1,
edgecolor=None
edgecolor=None,
)
if r == 0:
ax.plot(
Expand All @@ -288,11 +287,8 @@ def beta_type(value):
label=r"$\sigma = $" + str(sigma_list[n]),
)
else:
ax.plot(
range(n_epochs),
al,
color=color_list[n])
ax.axhline(y=sigma_list[n], color=color_list[n], ls='--')
ax.plot(range(n_epochs), al, color=color_list[n])
ax.axhline(y=sigma_list[n], color=color_list[n], ls="--")
ax.set_ylabel("Aleatoric Uncertainty")
ax.set_xlabel("Epoch")
if model[0:3] == "DER":
Expand All @@ -303,12 +299,12 @@ def beta_type(value):
plt.legend()
if config.get_item("analysis", "savefig", "Analysis"):
plt.savefig(
str(path_to_out)
+ "aleatoric_uncertainty_n_epochs_"
+ str(n_epochs)
+ "_n_models_DE_"
+ str(n_models)
+ ".png"
)
str(path_to_out)
+ "aleatoric_uncertainty_n_epochs_"
+ str(n_epochs)
+ "_n_models_DE_"
+ str(n_models)
+ ".png"
)
if config.get_item("analysis", "plot", "Analysis"):
plt.show()
52 changes: 26 additions & 26 deletions src/scripts/AleatoricNHidden.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
def parse_args():
parser = argparse.ArgumentParser(
description="Analyzes the aleatoric uncertainty when the model \
architecture is jittered")
architecture is jittered"
)
# there are three options with the parser:
# 1) Read from a yaml
# 2) Reads from the command line and default file
Expand All @@ -25,8 +26,7 @@ def parse_args():
# model
# we need some info about the model to run this analysis
# path to save the model results
parser.add_argument("--dir",
default=DefaultsAnalysis["common"]["dir"])
parser.add_argument("--dir", default=DefaultsAnalysis["common"]["dir"])
# now args for model
parser.add_argument(
"--n_models",
Expand Down Expand Up @@ -120,11 +120,13 @@ def parse_args():
# check if args were specified in cli
input_yaml = {
"common": {"dir": args.dir},
"model": {"n_models": args.n_models,
"n_epochs": args.n_epochs,
"BETA": args.BETA,
"COEFF": args.COEFF,
"loss_type": args.loss_type},
"model": {
"n_models": args.n_models,
"n_epochs": args.n_epochs,
"BETA": args.BETA,
"COEFF": args.COEFF,
"loss_type": args.loss_type,
},
"analysis": {
"noise_level_list": args.noise_level_list,
"model_names_list": args.model_names_list,
Expand Down Expand Up @@ -178,10 +180,10 @@ def beta_type(value):
rs = 1
# check that this exists and if not make it
if not os.path.isdir(path_to_out):
print('does not exist, making dir', path_to_out)
print("does not exist, making dir", path_to_out)
os.mkdir(path_to_out)
else:
print('already exists', path_to_out)
print("already exists", path_to_out)
model_name_list = config.get_item("analysis",
"model_names_list",
"Analysis")
Expand Down Expand Up @@ -220,7 +222,7 @@ def beta_type(value):
load_rs_chk=True,
rs=rs,
load_nh_chk=True,
nh=nh
nh=nh,
)
# path=path_to_chk)
# things to grab: 'valid_mse' and 'valid_bnll'
Expand Down Expand Up @@ -250,7 +252,8 @@ def beta_type(value):
list_vars.append(var_vals)
try:
al_dict[model][noise][nmodels + 1].append(
np.mean(list_vars))
np.mean(list_vars)
)
except KeyError:
continue
# make a two-paneled plot for the different noise levels
Expand All @@ -277,9 +280,9 @@ def beta_type(value):
al + al_std,
color=color_list[n],
alpha=0.1,
edgecolor=None
edgecolor=None,
)

if h == 0:
ax.plot(
range(n_epochs),
Expand All @@ -288,11 +291,8 @@ def beta_type(value):
label=r"$\sigma = $" + str(sigma_list[n]),
)
else:
ax.plot(
range(n_epochs),
al,
color=color_list[n])
ax.axhline(y=sigma_list[n], color=color_list[n], ls='--')
ax.plot(range(n_epochs), al, color=color_list[n])
ax.axhline(y=sigma_list[n], color=color_list[n], ls="--")
ax.set_ylabel("Aleatoric Uncertainty")
ax.set_xlabel("Epoch")
if model[0:3] == "DER":
Expand All @@ -303,12 +303,12 @@ def beta_type(value):
plt.legend()
if config.get_item("analysis", "savefig", "Analysis"):
plt.savefig(
str(path_to_out)
+ "aleatoric_uncertainty_n_epochs_"
+ str(n_epochs)
+ "_n_models_DE_"
+ str(n_models)
+ ".png"
)
str(path_to_out)
+ "aleatoric_uncertainty_n_epochs_"
+ str(n_epochs)
+ "_n_models_DE_"
+ str(n_models)
+ ".png"
)
if config.get_item("analysis", "plot", "Analysis"):
plt.show()
1 change: 0 additions & 1 deletion src/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def train_DER(
save_n_hidden=False,
n_hidden=64,
verbose=True,

):
# first determine if you even need to run anything
if not save_all_checkpoints and save_final_checkpoint:
Expand Down

0 comments on commit d701989

Please sign in to comment.