From 14f9fa9cc51d808cca63f9030c10e0a89143cd5f Mon Sep 17 00:00:00 2001 From: beckynevin Date: Wed, 20 Mar 2024 10:07:03 -0600 Subject: [PATCH] train DE mods to accept beta and other args, option to save only final chk --- src/scripts/models.py | 3 +- src/scripts/train.py | 129 ++++++++++++++++++++++++++++++------------ 2 files changed, 95 insertions(+), 37 deletions(-) diff --git a/src/scripts/models.py b/src/scripts/models.py index 8bf14f4..fda698c 100644 --- a/src/scripts/models.py +++ b/src/scripts/models.py @@ -74,10 +74,9 @@ def model_setup_DE(loss_type, DEVICE):#, INIT_LR=0.001): #model = de_var().to(DEVICE) Layer = MuVarLayer lossFn = loss_bnll - #opt = torch.optim.Adam(model.parameters(), lr=INIT_LR) model = torch.nn.Sequential(Model(2), Layer()) model = model.to(DEVICE) - return model, lossFn#, opt + return model, lossFn class de_no_var(nn.Module): diff --git a/src/scripts/train.py b/src/scripts/train.py index 7a5f960..c4172bf 100644 --- a/src/scripts/train.py +++ b/src/scripts/train.py @@ -189,9 +189,11 @@ def train_DE( loss_type, n_models, model_name="DE", + BETA=None, EPOCHS=100, path_to_model="models/", - save_checkpoints=False, + save_all_checkpoints=False, + save_final_checkpoint=False, plot=True, savefig=True, verbose=True, @@ -201,7 +203,7 @@ def train_DE( ''' # Find last epoch saved - if save_checkpoints is True: + if save_checkpoints: print('looking for saved checkpts', glob.glob("models/*" + model_name + "*")) @@ -249,7 +251,7 @@ def train_DE( print("epoch", epoch, round(e / EPOCHS, 2)) loss_this_epoch = [] - if plot is True: + if plot: plt.clf() fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), gridspec_kw={'height_ratios': [3, 1]} @@ -282,13 +284,13 @@ def train_DE( beta_epoch = 0.5 # 1 - e / EPOCHS # this one doesn't work great ''' - beta_epoch = 1 - e / EPOCHS - #beta_epoch = 1.0 + #beta_epoch = 1 - e / EPOCHS + beta_epoch = BETA loss = lossFn(pred[:, 0].flatten(), pred[:, 1].flatten(), y, beta=beta_epoch) - if plot is True: + if plot: if (e % (EPOCHS-1) == 0) and (e != 0): if loss_type == "no_var_loss": ax1.scatter(y, pred.flatten().detach().numpy(), @@ -356,7 +358,7 @@ def train_DE( print("new best loss", loss, "in epoch", epoch) # best_weights = copy.deepcopy(model.state_dict()) # print('validation loss', mse) - if (plot is True) and (e % (EPOCHS-1) == 0) and (e != 0): + if (plot or savefig) and (e % (EPOCHS-1) == 0) and (e != 0): ax1.plot(range(0, 1000), range(0, 1000), color='black', @@ -405,8 +407,6 @@ def train_DE( ax2.axhline(0, color='black', linestyle='--', linewidth=1) ax2.set_ylabel("Residuals") ax2.set_xlabel("True Value") - - # add annotion for loss value if loss_type == "bnll_loss": ax1.annotate(r'$\beta = $' + @@ -421,44 +421,103 @@ def train_DE( ax1.annotate(str(loss_type) + ' = ' + str(round(loss,2)) + '\n' + r'MSE = ' + str(round(mse,2)), xy=(0.73, 0.1), xycoords='axes fraction', - bbox=dict(boxstyle='round,pad=0.5', facecolor='lightgrey', alpha=0.5)) - - - + bbox=dict(boxstyle='round,pad=0.5', facecolor='lightgrey', alpha=0.5)) ax1.set_ylabel("Prediction") ax1.set_title("Epoch " + str(e)) ax1.set_xlim([0, 1000]) ax1.set_ylim([0, 1000]) ax1.legend() - if savefig is True: + if savefig: ax1.errorbar(200, 600, yerr=5, color='red', capsize=2) plt.savefig("../images/animations/" + str(model_name) + "_nmodel_" + str(m) + "_beta_" + str(beta_epoch) + "_epoch_" + str(epoch) + ".png") - plt.show() + if plot: + plt.show() plt.close() - if save_checkpoints is True: - - torch.save( - { - "epoch": epoch, - "model_state_dict": model.state_dict(), - "optimizer_state_dict": opt.state_dict(), - "train_loss": np.mean(loss_this_epoch), - "valid_loss": loss, - "valid_mean": y_pred[:, 0].flatten(), - "valid_sigma": y_pred[:, 1].flatten(), - "x_val": x_val, - "y_val": y_val, - }, - path_to_model + "/" + - str(model_name) + "_nmodel_" + - str(m) + "_epoch_" + - str(epoch) + ".pt", - ) + if save_all_checkpoints: + if loss_type == "bnll_loss": + torch.save( + { + "epoch": epoch, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": opt.state_dict(), + "train_loss": np.mean(loss_this_epoch), + "valid_loss": loss, + "valid_mse": mse, + "valid_mean": y_pred[:, 0].flatten(), + "valid_sigma": y_pred[:, 1].flatten(), + "x_val": x_val, + "y_val": y_val, + }, + path_to_model + "/" + + str(model_name) + "_beta_" + str(beta_epoch) + + "_nmodel_" + str(m) + + "_epoch_" + str(epoch) + ".pt", + ) + else: + torch.save( + { + "epoch": epoch, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": opt.state_dict(), + "train_loss": np.mean(loss_this_epoch), + "valid_loss": loss, + "valid_mse": mse, + "valid_mean": y_pred[:, 0].flatten(), + "valid_sigma": y_pred[:, 1].flatten(), + "x_val": x_val, + "y_val": y_val, + }, + path_to_model + "/" + + str(model_name) + "_nmodel_" + + str(m) + "_epoch_" + + str(epoch) + ".pt", + ) + if save_final_checkpoint and (e % (EPOCHS-1) == 0) and (e != 0): + # option to just save final epoch + if loss_type == "bnll_loss": + torch.save( + { + "epoch": epoch, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": opt.state_dict(), + "train_loss": np.mean(loss_this_epoch), + "valid_loss": loss, + "valid_mse": mse, + "valid_mean": y_pred[:, 0].flatten(), + "valid_sigma": y_pred[:, 1].flatten(), + "x_val": x_val, + "y_val": y_val, + }, + path_to_model + "/" + + str(model_name) + "_beta_" + str(beta_epoch) + + "_nmodel_" + str(m) + + "_epoch_" + str(epoch) + ".pt", + ) + else: + torch.save( + { + "epoch": epoch, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": opt.state_dict(), + "train_loss": np.mean(loss_this_epoch), + "valid_loss": loss, + "valid_mse": mse, + "valid_mean": y_pred[:, 0].flatten(), + "valid_sigma": y_pred[:, 1].flatten(), + "x_val": x_val, + "y_val": y_val, + }, + path_to_model + "/" + + str(model_name) + "_nmodel_" + + str(m) + "_epoch_" + + str(epoch) + ".pt", + ) + model_ensemble.append(model) final_mse.append(mse) @@ -468,7 +527,7 @@ def train_DE( print("start at", startTime, "end at", endTime) print(endTime - startTime) - return model_ensemble, final_mse + return model_ensemble if __name__ == "__main__":