Commit

argparse should work
beckynevin committed Apr 9, 2024
1 parent 35fb574 commit 83703cd
Showing 3 changed files with 195 additions and 41 deletions.
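The argparse entry point itself is not shown in this diff, so the following is only a sketch of how the renamed loss_type, wd, and checkpoint flags might be wired from the command line; every flag name below is an assumption, not taken from the commit.

# Hypothetical CLI wiring for train_DER -- all flag names are assumptions.
import argparse

parser = argparse.ArgumentParser(description="Train a DER model")
parser.add_argument("--loss_type", choices=["DER", "SDER"], default="DER",
                    help="selects the evidential output layer and loss")
parser.add_argument("--wd", default="./",
                    help="working-directory prefix for models/ and images/")
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--save_final_checkpoint", action="store_true")
parser.add_argument("--overwrite_final_checkpoint", action="store_true")
args = parser.parse_args()

# train_DER(..., loss_type=args.loss_type, wd=args.wd, EPOCHS=args.epochs,
#           save_final_checkpoint=args.save_final_checkpoint,
#           overwrite_final_checkpoint=args.overwrite_final_checkpoint)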
7 changes: 4 additions & 3 deletions src/scripts/models.py
@@ -28,13 +28,14 @@ def forward(self, x):
return torch.stack((gamma, nu, alpha, beta), dim=1)


def model_setup_DER(DER_type, DEVICE):
def model_setup_DER(loss_type, DEVICE):
print('loss type', loss_type, type(loss_type))
# initialize the model from scratch
if DER_type == "SDER":
if loss_type == "SDER":
Layer = SDERLayer
# initialize our loss function
lossFn = loss_sder
if DER_type == "DER":
if loss_type == "DER":
Layer = DERLayer
# initialize our loss function
lossFn = loss_der
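With the rename, callers now dispatch on the loss name rather than a separate model type. A minimal usage sketch of the new signature, assuming src/scripts is on sys.path and that the hidden lines of model_setup_DER build and return the model:

# Minimal sketch of the renamed call (assumed import path).
import torch
import models

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, lossFn = models.model_setup_DER("SDER", DEVICE)  # or "DER"

Note that in the hunk shown, a loss_type outside {"DER", "SDER"} would leave Layer and lossFn unbound, so any validation presumably happens in the lines not shown.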
225 changes: 191 additions & 34 deletions src/scripts/train.py
@@ -14,51 +14,69 @@ def train_DER(
INIT_LR,
DEVICE,
COEFF,
DER_type,
model_name,
EPOCHS=40,
save_checkpoints=False,
loss_type,
wd,
model_name="DER",
EPOCHS=100,
path_to_model="models/",
plot=False,
verbose=True,
save_all_checkpoints=False,
save_final_checkpoint=False,
overwrite_final_checkpoint=False,
plot=True,
savefig=True,
verbose=True
):
# first determine if you even need to run anything
if not save_all_checkpoints and save_final_checkpoint:
        # option to skip running the model if you don't care about
        # saving all checkpoints and only want to save the final one
final_chk = (
path_to_model
+ str(model_name)
+ "_loss_"
+ str(loss_type)
+ "_epoch_"
+ str(EPOCHS - 1)
+ ".pt"
)
if verbose:
print("final chk", final_chk)
# check if the final epoch checkpoint already exists
print(glob.glob(final_chk))
if glob.glob(final_chk):
print("final model already exists")
if overwrite_final_checkpoint:
print("going to overwrite final checkpoint")
else:
print("not overwriting, exiting")
return
else:
print("model does not exist yet, going to save")
# measure how long training is going to take
if verbose:
print("[INFO] training the network...")
print("saving checkpoints?")
print(save_checkpoints)
print("saving all checkpoints?")
print(save_all_checkpoints)
print("saving final checkpoint?")
print(save_final_checkpoint)
print("overwriting final checkpoint if its already there?")
print(overwrite_final_checkpoint)
print(f"saving here: {path_to_model}")
print(f"model name: {model_name}")

startTime = time.time()
start_epoch = 0
"""
# Find last epoch saved
if save_checkpoints:
print(glob.glob(path_to_model + "/" + str(model_name) + "*"))
list_models_run = []
for file in glob.glob(path_to_model + "/" + str(model_name) + "*"):
list_models_run.append(
float(str.split(str(str.split(file,
model_name + "_")[1]), ".")[0])
)
if list_models_run:
start_epoch = max(list_models_run) + 1
else:
start_epoch = 0
else:
start_epoch = 0
print("starting here", start_epoch)
"""

best_loss = np.inf # init to infinity
model, lossFn = models.model_setup_DER(DER_type, DEVICE)
model, lossFn = models.model_setup_DER(loss_type, DEVICE)
if verbose:
print('model is', model, 'lossfn', lossFn)

opt = torch.optim.Adam(model.parameters(), lr=INIT_LR)

# loop over our epochs
for e in range(0, EPOCHS):
if plot:
if plot or savefig:
plt.clf()
fig, (ax1, ax2) = plt.subplots(
2, 1, figsize=(8, 6), gridspec_kw={"height_ratios": [3, 1]}
@@ -84,8 +102,8 @@ def train_DER(

pred = model(x)
loss = lossFn(pred, y, COEFF)
if plot and (e % 5 == 0):
if i == 0:
if plot or savefig:
if (e % (EPOCHS - 1) == 0) and (e != 0):
pred_loader_0 = pred[:, 0].flatten().detach().numpy()
y_loader_0 = y.detach().numpy()
ax1.scatter(
@@ -108,10 +126,12 @@ def train_DER(
xycoords="axes fraction",
color="black",
)
'''
else:
ax1.scatter(y,
pred[:, 0].flatten().detach().numpy(),
color="grey")
'''
loss_this_epoch.append(loss[0].item())

# zero out the gradients
@@ -123,6 +143,112 @@ def train_DER(
# optimizer takes a step based on the gradients of the parameters
# here, its taking a step for every batch
opt.step()
if (plot or savefig) and (e % (EPOCHS - 1) == 0) and (e != 0):
ax1.plot(range(0, 1000),
range(0, 1000),
color="black",
ls="--")
if loss_type == "no_var_loss":
ax1.scatter(
y_val,
y_pred.flatten().detach().numpy(),
color="#F45866",
edgecolor="black",
zorder=100,
label="validation dtata",
)
else:
ax1.errorbar(
y_val,
y_pred[:, 0].flatten().detach().numpy(),
yerr=np.sqrt(y_pred[:, 1].flatten().detach().numpy()),
linestyle="None",
color="black",
capsize=2,
zorder=100,
)
ax1.scatter(
y_val,
y_pred[:, 0].flatten().detach().numpy(),
color="#9CD08F",
s=5,
zorder=101,
label="validation data",
)

# add residual plot
residuals = y_pred[:, 0].flatten().detach().numpy() - y_val
ax2.errorbar(
y_val,
residuals,
yerr=np.sqrt(y_pred[:, 1].flatten().detach().numpy()),
linestyle="None",
color="black",
capsize=2,
)
ax2.scatter(y_val, residuals, color="#9B287B", s=5, zorder=100)
ax2.axhline(0, color="black", linestyle="--", linewidth=1)
ax2.set_ylabel("Residuals")
ax2.set_xlabel("True Value")
            # add annotation for loss value
if loss_type == "bnll_loss":
ax1.annotate(
r"$\beta = $"
+ str(round(beta_epoch, 2))
+ "\n"
+ str(loss_type)
+ " = "
+ str(round(loss, 2))
+ "\n"
+ r"MSE = "
+ str(round(mse, 2)),
xy=(0.73, 0.1),
xycoords="axes fraction",
bbox=dict(
boxstyle="round,pad=0.5",
facecolor="lightgrey",
alpha=0.5
),
)

else:
ax1.annotate(
str(loss_type)
+ " = "
+ str(round(loss, 2))
+ "\n"
+ r"MSE = "
+ str(round(mse, 2)),
xy=(0.73, 0.1),
xycoords="axes fraction",
bbox=dict(
boxstyle="round,pad=0.5",
facecolor="lightgrey",
alpha=0.5
),
)
ax1.set_ylabel("Prediction")
ax1.set_title("Epoch " + str(e))
ax1.set_xlim([0, 1000])
ax1.set_ylim([0, 1000])
ax1.legend()
if savefig:
# ax1.errorbar(200, 600, yerr=5,
# color='red', capsize=2)
plt.savefig(
str(wd)
+ "images/animations/"
+ str(model_name)
+ "_loss_"
+ str(loss_type)
+ "_epoch_"
+ str(epoch)
+ ".png"
)
if plot:
plt.show()
plt.close()
'''
if plot and (e % 5 == 0):
ax1.set_ylabel("prediction")
ax1.set_title("Epoch " + str(e))
@@ -136,6 +262,7 @@ def train_DER(
plt.show()
plt.close()
'''
model.eval()
y_pred = model(torch.Tensor(x_val))
loss = lossFn(y_pred, torch.Tensor(y_val), COEFF)
@@ -155,7 +282,7 @@ def train_DER(
# best_weights = copy.deepcopy(model.state_dict())
# print('validation loss', mse)

if save_checkpoints:
if save_all_checkpoints:

torch.save(
{
@@ -170,9 +297,39 @@ def train_DER(
"std_u_al_validation": std_u_al_val,
"std_u_ep_validation": std_u_ep_val,
},
path_to_model + "/" + str(model_name)
+ "_epoch_" + str(epoch) + ".pt",
str(wd)
+ "models/"
+ str(model_name)
+ "_loss_"
+ str(loss_type)
+ "_epoch_"
+ str(epoch)
+ ".pt",
)
if save_final_checkpoint and (e % (EPOCHS - 1) == 0) and (e != 0):
            # option to save only the final epoch
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": opt.state_dict(),
"train_loss": np.mean(loss_this_epoch),
"valid_loss": NIGloss_val,
"valid_mse": mse,
"med_u_al_validation": med_u_al_val,
"med_u_ep_validation": med_u_ep_val,
"std_u_al_validation": std_u_al_val,
"std_u_ep_validation": std_u_ep_val,
},
str(wd)
+ "models/"
+ str(model_name)
+ "_loss_"
+ str(loss_type)
+ "_epoch_"
+ str(epoch)
+ ".pt",
)
endTime = time.time()
if verbose:
print("start at", startTime, "end at", endTime)
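For reference, a checkpoint written by the save_final_checkpoint branch above could be restored roughly as follows. The path mirrors the str(wd) + "models/" construction in the diff; it assumes wd ends with a trailing slash, that epoch equals EPOCHS - 1 on the final pass, and that the placeholder values below match the training run.

# Sketch of restoring the final checkpoint saved by train_DER
# (assumptions noted above; values here are placeholders).
import torch
import models

DEVICE = torch.device("cpu")
wd, model_name, loss_type, EPOCHS = "./", "DER", "DER", 100  # assumed values
ckpt_path = f"{wd}models/{model_name}_loss_{loss_type}_epoch_{EPOCHS - 1}.pt"

checkpoint = torch.load(ckpt_path, map_location=DEVICE)
model, lossFn = models.model_setup_DER(loss_type, DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
print("valid MSE:", checkpoint["valid_mse"])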
4 changes: 0 additions & 4 deletions test/test_DeepEnsemble.py
@@ -1,13 +1,9 @@
import sys
import pytest
import torch
import numpy as np
import sbi
import os
import subprocess
import tempfile
import shutil
import unittest

# flake8: noqa
sys.path.append("..")
