Skip to content

Commit

Permalink
tests run locally, added config parameter for rs
Browse files Browse the repository at this point in the history
  • Loading branch information
beckynevin committed May 16, 2024
1 parent 16f5ac5 commit 862d079
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 51 deletions.
155 changes: 104 additions & 51 deletions src/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def train_DER(
start_epoch = 0

if set_and_save_rs:
print('setting and saving the rs')
# Set the random seed
set_random_seeds(seed_value=rs)

Expand Down Expand Up @@ -257,59 +258,111 @@ def train_DER(
plt.show()
plt.close()
if save_all_checkpoints:

torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": opt.state_dict(),
"train_loss": np.mean(loss_this_epoch),
"valid_loss": NIGloss_val,
"valid_mse": mse,
"med_u_al_validation": med_u_al_val,
"med_u_ep_validation": med_u_ep_val,
"std_u_al_validation": std_u_al_val,
"std_u_ep_validation": std_u_ep_val,
},
str(path_to_model)
+ "checkpoints/"
+ str(model_name)
+ "_loss_"
+ str(loss_type)
+ "_COEFF_"
+ str(COEFF)
+ "_epoch_"
+ str(epoch)
+ "_rs_"
+ str(rs)
+ ".pt",
)
if set_and_save_rs:
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": opt.state_dict(),
"train_loss": np.mean(loss_this_epoch),
"valid_loss": NIGloss_val,
"valid_mse": mse,
"med_u_al_validation": med_u_al_val,
"med_u_ep_validation": med_u_ep_val,
"std_u_al_validation": std_u_al_val,
"std_u_ep_validation": std_u_ep_val,
},
str(path_to_model)
+ "checkpoints/"
+ str(model_name)
+ "_loss_"
+ str(loss_type)
+ "_COEFF_"
+ str(COEFF)
+ "_epoch_"
+ str(epoch)
+ "_rs_"
+ str(rs)
+ ".pt",
)
else:
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": opt.state_dict(),
"train_loss": np.mean(loss_this_epoch),
"valid_loss": NIGloss_val,
"valid_mse": mse,
"med_u_al_validation": med_u_al_val,
"med_u_ep_validation": med_u_ep_val,
"std_u_al_validation": std_u_al_val,
"std_u_ep_validation": std_u_ep_val,
},
str(path_to_model)
+ "checkpoints/"
+ str(model_name)
+ "_loss_"
+ str(loss_type)
+ "_COEFF_"
+ str(COEFF)
+ "_epoch_"
+ str(epoch)
+ ".pt",
)
if save_final_checkpoint and (e % (EPOCHS - 1) == 0) and (e != 0):
# option to just save final epoch
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": opt.state_dict(),
"train_loss": np.mean(loss_this_epoch),
"valid_loss": NIGloss_val,
"valid_mse": mse,
"med_u_al_validation": med_u_al_val,
"med_u_ep_validation": med_u_ep_val,
"std_u_al_validation": std_u_al_val,
"std_u_ep_validation": std_u_ep_val,
},
str(path_to_model)
+ "checkpoints/"
+ str(model_name)
+ "_loss_"
+ str(loss_type)
+ "_COEFF_"
+ str(COEFF)
+ "_epoch_"
+ str(epoch)
+ ".pt",
)
if set_and_save_rs:
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": opt.state_dict(),
"train_loss": np.mean(loss_this_epoch),
"valid_loss": NIGloss_val,
"valid_mse": mse,
"med_u_al_validation": med_u_al_val,
"med_u_ep_validation": med_u_ep_val,
"std_u_al_validation": std_u_al_val,
"std_u_ep_validation": std_u_ep_val,
},
str(path_to_model)
+ "checkpoints/"
+ str(model_name)
+ "_loss_"
+ str(loss_type)
+ "_COEFF_"
+ str(COEFF)
+ "_epoch_"
+ str(epoch)
+ "_rs_"
+ str(rs)
+ ".pt",
)
else:
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": opt.state_dict(),
"train_loss": np.mean(loss_this_epoch),
"valid_loss": NIGloss_val,
"valid_mse": mse,
"med_u_al_validation": med_u_al_val,
"med_u_ep_validation": med_u_ep_val,
"std_u_al_validation": std_u_al_val,
"std_u_ep_validation": std_u_ep_val,
},
str(path_to_model)
+ "checkpoints/"
+ str(model_name)
+ "_loss_"
+ str(loss_type)
+ "_COEFF_"
+ str(COEFF)
+ "_epoch_"
+ str(epoch)
+ ".pt")
endTime = time.time()
if verbose:
print("start at", startTime, "end at", endTime)
Expand Down
2 changes: 2 additions & 0 deletions test/test_Aleatoric.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ def create_test_config_DER(
"overwrite_final_checkpoint": True,
"plot": False,
"savefig": False,
"save_chk_random_seed_init": False,
"rs": 42,
"verbose": False,
},
"data": {
Expand Down
2 changes: 2 additions & 0 deletions test/test_DeepEvidentialRegression.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def create_test_config(
"overwrite_final_checkpoint": True,
"plot": False,
"savefig": True,
"save_chk_random_seed_init": False,
"rs": 42,
"verbose": False,
},
"data": {
Expand Down

0 comments on commit 862d079

Please sign in to comment.