Skip to content

Commit

Permalink
flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
beckynevin committed May 16, 2024
1 parent 0dd7451 commit 16f5ac5
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 28 deletions.
8 changes: 4 additions & 4 deletions src/analyze/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ def load_checkpoint(
)
else:
file_name = (
str(path)
+ f"{model_name}_noise_{noise}_loss_{loss}"
+ f"_COEFF_{COEFF}_epoch_{epoch}.pt"
)
str(path)
+ f"{model_name}_noise_{noise}_loss_{loss}"
+ f"_COEFF_{COEFF}_epoch_{epoch}.pt"
)
checkpoint = torch.load(file_name, map_location=device)
elif model_name[0:2] == "DE":
file_name = (
Expand Down
47 changes: 25 additions & 22 deletions src/scripts/AleatoricInits.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,28 +225,29 @@ def beta_type(value):
al_dict[model][noise][rs].append(aleatoric_m)
al_std_dict[model][noise][rs].append(a_std)

elif model[0:2] == "DE":
n_models = config.get_item("model", "n_models", "DE")
for epoch in range(n_epochs):
list_mus = []
list_sigs = []
for nmodels in range(n_models):
chk = chk_module.load_checkpoint(
model,
noise,
epoch,
DEVICE,
path=path_to_chk,
BETA=BETA,
nmodel=nmodels,
)
mu_vals, sig_vals = chk_module.ep_al_checkpoint_DE(chk)
list_mus.append(mu_vals)
list_sigs.append(sig_vals)
al_dict[model][noise][rs].append(np.median(np.mean(list_sigs,
axis=0)))
al_std_dict[model][noise][rs].append(np.std(np.mean(list_sigs,
axis=0)))
if model[0:3] == "DE_":
n_models = config.get_item("model", "n_models", "DE")
for epoch in range(n_epochs):
list_mus = []
list_sigs = []
for nmodels in range(n_models):
chk = chk_module.load_checkpoint(
model,
noise,
epoch,
DEVICE,
path=path_to_chk,
BETA=BETA,
nmodel=nmodels,
)
mu_vals, sig_vals = chk_module.ep_al_checkpoint_DE(chk)
list_mus.append(mu_vals)
list_sigs.append(sig_vals)
try:
al_dict[model][noise][nmodels + 1].append(
np.mean(list_sigs))
except KeyError:
continue
# make a two-paneled plot for the different noise levels
# make one panel per model
# for the noise levels:
Expand All @@ -260,6 +261,7 @@ def beta_type(value):
for n, noise in enumerate(noise_list):
for r, rs in enumerate(rs_list):
al = np.array(np.sqrt(al_dict[model][noise][rs]))
'''
al_std = np.array(np.sqrt(al_std_dict[model][noise][rs]))
ax.fill_between(
range(n_epochs),
Expand All @@ -269,6 +271,7 @@ def beta_type(value):
alpha=0.0,
edgecolor=None
)
'''
if r == 0:
ax.plot(
range(n_epochs),
Expand Down
5 changes: 3 additions & 2 deletions src/utils/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,14 @@
},
"analysis": {
"noise_level_list": ["low", "medium", "high"],
"model_names_list": ["DER"],#, "DE_desiderata_2"],
"model_names_list": ["DER", "DE_desiderata_2"],
# ["DER_desiderata_2", "DE_desiderata_2"]
"plot": True,
"savefig": False,
"verbose": False,
},
"plots": {"color_list": ["#F4D58D", "#339989", "#292F36", "#04A777", "#DF928E"]},
"plots": {"color_list":
["#F4D58D", "#339989", "#292F36", "#04A777", "#DF928E"]},
# Pinks ["#EC4067", "#A01A7D", "#311847"]},
# Blues: ["#8EA8C3", "#406E8E", "#23395B"]},
"metrics_common": {
Expand Down

0 comments on commit 16f5ac5

Please sign in to comment.