Skip to content

Commit

Permalink
train DE mods to accept beta and other args, option to save only fina…
Browse files Browse the repository at this point in the history
…l chk
  • Loading branch information
beckynevin committed Mar 20, 2024
1 parent 1d5c25b commit 14f9fa9
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 37 deletions.
3 changes: 1 addition & 2 deletions src/scripts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,9 @@ def model_setup_DE(loss_type, DEVICE):#, INIT_LR=0.001):
#model = de_var().to(DEVICE)
Layer = MuVarLayer
lossFn = loss_bnll
#opt = torch.optim.Adam(model.parameters(), lr=INIT_LR)
model = torch.nn.Sequential(Model(2), Layer())
model = model.to(DEVICE)
return model, lossFn#, opt
return model, lossFn


class de_no_var(nn.Module):
Expand Down
129 changes: 94 additions & 35 deletions src/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,11 @@ def train_DE(
loss_type,
n_models,
model_name="DE",
BETA=None,
EPOCHS=100,
path_to_model="models/",
save_checkpoints=False,
save_all_checkpoints=False,
save_final_checkpoint=False,
plot=True,
savefig=True,
verbose=True,
Expand All @@ -201,7 +203,7 @@ def train_DE(

'''
# Find last epoch saved
if save_checkpoints is True:
if save_checkpoints:
print('looking for saved checkpts',
glob.glob("models/*" + model_name + "*"))
Expand Down Expand Up @@ -249,7 +251,7 @@ def train_DE(
print("epoch", epoch, round(e / EPOCHS, 2))

loss_this_epoch = []
if plot is True:
if plot:
plt.clf()
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6),
gridspec_kw={'height_ratios': [3, 1]}
Expand Down Expand Up @@ -282,13 +284,13 @@ def train_DE(
beta_epoch = 0.5
# 1 - e / EPOCHS # this one doesn't work great
'''
beta_epoch = 1 - e / EPOCHS
#beta_epoch = 1.0
#beta_epoch = 1 - e / EPOCHS
beta_epoch = BETA
loss = lossFn(pred[:, 0].flatten(),
pred[:, 1].flatten(),
y,
beta=beta_epoch)
if plot is True:
if plot:
if (e % (EPOCHS-1) == 0) and (e != 0):
if loss_type == "no_var_loss":
ax1.scatter(y, pred.flatten().detach().numpy(),
Expand Down Expand Up @@ -356,7 +358,7 @@ def train_DE(
print("new best loss", loss, "in epoch", epoch)
# best_weights = copy.deepcopy(model.state_dict())
# print('validation loss', mse)
if (plot is True) and (e % (EPOCHS-1) == 0) and (e != 0):
if (plot or savefig) and (e % (EPOCHS-1) == 0) and (e != 0):
ax1.plot(range(0, 1000),
range(0, 1000),
color='black',
Expand Down Expand Up @@ -405,8 +407,6 @@ def train_DE(
ax2.axhline(0, color='black', linestyle='--', linewidth=1)
ax2.set_ylabel("Residuals")
ax2.set_xlabel("True Value")


# add annotion for loss value
if loss_type == "bnll_loss":
ax1.annotate(r'$\beta = $' +
Expand All @@ -421,44 +421,103 @@ def train_DE(
ax1.annotate(str(loss_type) + ' = ' + str(round(loss,2)) + '\n' + r'MSE = ' + str(round(mse,2)),
xy=(0.73, 0.1),
xycoords='axes fraction',
bbox=dict(boxstyle='round,pad=0.5', facecolor='lightgrey', alpha=0.5))



bbox=dict(boxstyle='round,pad=0.5', facecolor='lightgrey', alpha=0.5))
ax1.set_ylabel("Prediction")
ax1.set_title("Epoch " + str(e))
ax1.set_xlim([0, 1000])
ax1.set_ylim([0, 1000])
ax1.legend()
if savefig is True:
if savefig:
ax1.errorbar(200, 600, yerr=5,
color='red', capsize=2)
plt.savefig("../images/animations/" +
str(model_name) + "_nmodel_" +
str(m) + "_beta_" + str(beta_epoch) +
"_epoch_" + str(epoch) + ".png")
plt.show()
if plot:
plt.show()
plt.close()

if save_checkpoints is True:

torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": opt.state_dict(),
"train_loss": np.mean(loss_this_epoch),
"valid_loss": loss,
"valid_mean": y_pred[:, 0].flatten(),
"valid_sigma": y_pred[:, 1].flatten(),
"x_val": x_val,
"y_val": y_val,
},
path_to_model + "/" +
str(model_name) + "_nmodel_" +
str(m) + "_epoch_" +
str(epoch) + ".pt",
)
if save_all_checkpoints:
if loss_type == "bnll_loss":
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": opt.state_dict(),
"train_loss": np.mean(loss_this_epoch),
"valid_loss": loss,
"valid_mse": mse,
"valid_mean": y_pred[:, 0].flatten(),
"valid_sigma": y_pred[:, 1].flatten(),
"x_val": x_val,
"y_val": y_val,
},
path_to_model + "/" +
str(model_name) + "_beta_" + str(beta_epoch) +
"_nmodel_" + str(m) +
"_epoch_" + str(epoch) + ".pt",
)
else:
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": opt.state_dict(),
"train_loss": np.mean(loss_this_epoch),
"valid_loss": loss,
"valid_mse": mse,
"valid_mean": y_pred[:, 0].flatten(),
"valid_sigma": y_pred[:, 1].flatten(),
"x_val": x_val,
"y_val": y_val,
},
path_to_model + "/" +
str(model_name) + "_nmodel_" +
str(m) + "_epoch_" +
str(epoch) + ".pt",
)
if save_final_checkpoint and (e % (EPOCHS-1) == 0) and (e != 0):
# option to just save final epoch
if loss_type == "bnll_loss":
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": opt.state_dict(),
"train_loss": np.mean(loss_this_epoch),
"valid_loss": loss,
"valid_mse": mse,
"valid_mean": y_pred[:, 0].flatten(),
"valid_sigma": y_pred[:, 1].flatten(),
"x_val": x_val,
"y_val": y_val,
},
path_to_model + "/" +
str(model_name) + "_beta_" + str(beta_epoch) +
"_nmodel_" + str(m) +
"_epoch_" + str(epoch) + ".pt",
)
else:
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": opt.state_dict(),
"train_loss": np.mean(loss_this_epoch),
"valid_loss": loss,
"valid_mse": mse,
"valid_mean": y_pred[:, 0].flatten(),
"valid_sigma": y_pred[:, 1].flatten(),
"x_val": x_val,
"y_val": y_val,
},
path_to_model + "/" +
str(model_name) + "_nmodel_" +
str(m) + "_epoch_" +
str(epoch) + ".pt",
)


model_ensemble.append(model)
final_mse.append(mse)
Expand All @@ -468,7 +527,7 @@ def train_DE(
print("start at", startTime, "end at", endTime)
print(endTime - startTime)

return model_ensemble, final_mse
return model_ensemble


if __name__ == "__main__":
Expand Down

0 comments on commit 14f9fa9

Please sign in to comment.