diff --git a/SeasonTST/SeasonTST_finetune.py b/SeasonTST/SeasonTST_finetune.py
index 233db0f7..1e4ec229 100644
--- a/SeasonTST/SeasonTST_finetune.py
+++ b/SeasonTST/SeasonTST_finetune.py
@@ -39,7 +39,7 @@
     datefmt="%m/%d/%Y %I:%M:%S %p",
     filename=f'logs/{datetime.datetime.now().strftime("%Y_%m_%d_%I:%M")}_finetune.log',
     encoding="utf-8",
-    level=logging.DEBUG,
+    level=logging.INFO,
 )
@@ -51,10 +51,6 @@
 def finetune_func(learner, save_path, args, lr=0.001):
     print("end-to-end finetuning")

-    if not os.path.exists(save_path):
-        os.makedirs(save_path)
-
-    print(save_path)
     # fit the data to the model and save
     learner.fine_tune(
         n_epochs=args.n_epochs_finetune, base_lr=lr, freeze_epochs=args.freeze_epochs
@@ -107,20 +103,6 @@
     )


-def test_func(weight_path, learner, args, dls):
-
-    out = learner.test(
-        dls.test, weight_path=weight_path, scores=[mse, mae]
-    )  # out: a list of [pred, targ, score]
-    print("score:", out[2])
-    # save results
-    pd.DataFrame(np.array(out[2]).reshape(1, -1), columns=["mse", "mae"]).to_csv(
-        args.save_path + args.save_finetuned_model + "_acc.csv",
-        float_format="%.6f",
-        index=False,
-    )
-    return out
-

 def load_config():
@@ -135,13 +117,14 @@
         "revin": 0,  # reversible instance normalization
         "mask_ratio": 0.4,  # masking ratio for the input
         "lr": 1e-3,
-        "batch_size": 128,
+        "batch_size": 64,
+        "drop_last": False,
         "num_workers": 6,
         "prefetch_factor": 3,
-        "n_epochs_pretrain": 1,  # number of pre-training epochs,
+        "n_epochs_pretrain": 20,  # number of pre-training epochs,
         "freeze_epochs": 0,
-        "n_epochs_finetune": 250,
-        "pretrained_model_id": 2500,  # id of the saved pretrained model
+        "n_epochs_finetune": 1,
+        "pretrained_model_id": 2,  # id of the saved pretrained model
         "save_finetuned_model": "./finetuned_d128",
         "save_path": "saved_models" + "/masked_patchtst/",
     }
@@ -186,15 +169,15 @@
     # Create dataloader
     dls = get_dls(config_obj, SeasonTST_Dataset, data, mask)

-    # suggested_lr = find_lr(config_obj, dls) # This is what I got on a small dataset. In case one wants to skip this for testing.
-    suggested_lr = 0.00017073526474706903
+    suggested_lr = 0.0002  # 0.000298364724028334
+
+    learner = get_learner(config_obj, dls, suggested_lr, model)
+    suggested_lr = learner.lr_finder()
     print(suggested_lr)

-    learner = get_learner(config_obj, dls, suggested_lr, model)
     # This function will save the model weights to config_obj.save_finetuned_model. ie will not overwrite the pretrained model.
-    # However, there is currently no set-up to do finetuning from the result of a previous finetuning.
+    # To continue training from a previous fine-tuning checkpoint, the path needs to be explicitly fed to the get_model function
     finetune_func(learner, pretrained_model_path, config_obj, suggested_lr)
diff --git a/SeasonTST/SeasonTST_pretrain.py b/SeasonTST/SeasonTST_pretrain.py
index a84c9190..6032c448 100644
--- a/SeasonTST/SeasonTST_pretrain.py
+++ b/SeasonTST/SeasonTST_pretrain.py
@@ -29,7 +29,7 @@
     datefmt="%m/%d/%Y %I:%M:%S %p",
     filename=f'logs/{datetime.datetime.now().strftime("%Y_%m_%d_%I_%M")}_train.log',
     encoding="utf-8",
-    level=logging.DEBUG,
+    level=logging.INFO,
 )
@@ -95,10 +95,11 @@
         "mask_value": -99,  # Value to assign to masked elements of data input
         "lr": 1e-3,
         "batch_size": 128,
+        "drop_last": True,
         "prefetch_factor": 3,
         "num_workers": 6,
-        "n_epochs_pretrain": 1,  # number of pre-training epochs
-        "pretrained_model_id": 2500,  # id of the saved pretrained model
+        "n_epochs_pretrain": 20,  # number of pre-training epochs
+        "pretrained_model_id": 2,  # id of the saved pretrained model
     }

     config_obj = SimpleNamespace(**config)
@@ -109,37 +110,42 @@
 def main():
     data, mask = load_data()
     config_obj = load_config()

+    save_path = "saved_models" + "/masked_patchtst/"
+    pretrained_model = (
+        "patchtst_pretrained_cw"
+        + str(config_obj.sequence_length)
+        + "_patch"
+        + str(config_obj.patch_len)
+        + "_stride"
+        + str(config_obj.stride)
+        + "_epochs-pretrain"
+        + str(config_obj.n_epochs_pretrain)
+        + "_mask"
+        + str(config_obj.mask_ratio)
+        + "_model"
+        + str(config_obj.pretrained_model_id)
+    )
+    pretrained_model_path = save_path + pretrained_model + ".pth"
+
     # Creates train valid and test datasets for one epoch. Notice that they are in different locations!
     dls = get_dls(config_obj, SeasonTST_Dataset, data, mask)

-    model = get_model(config_obj)
+
+    model = get_model(
+        config_obj, headtype="pretrain", weights_path=pretrained_model_path, exclude_head=False
+    )

     # suggested_lr = find_lr(config_obj, dls) # This is what I got on a small dataset. In case one wants to skip this for testing.
     suggested_lr = 0.00020565123083486514

-    save_pretrained_model = (
-        "patchtst_pretrained_cw"
-        + str(config_obj.sequence_length)
-        + "_patch"
-        + str(config_obj.patch_len)
-        + "_stride"
-        + str(config_obj.stride)
-        + "_epochs-pretrain"
-        + str(config_obj.n_epochs_pretrain)
-        + "_mask"
-        + str(config_obj.mask_ratio)
-        + "_model"
-        + str(config_obj.pretrained_model_id)
-    )
-    save_path = "saved_models" + "/masked_patchtst/"
+
+
     pretrain_func(
-        save_pretrained_model, save_path, config_obj, model, dls, suggested_lr
+        pretrained_model, save_path, config_obj, model, dls, suggested_lr
     )

-    pretrained_model_name = save_path + save_pretrained_model + ".pth"
-
-    model = transfer_weights(pretrained_model_name, model)
+    model = transfer_weights(pretrained_model_path, model)


 if __name__ == "__main__":
diff --git a/SeasonTST/utils.py b/SeasonTST/utils.py
index 9e64961b..57343c7b 100644
--- a/SeasonTST/utils.py
+++ b/SeasonTST/utils.py
@@ -75,17 +75,14 @@
     return model


-def find_lr(config_obj, dls):
+def find_lr(model, config_obj, dls):
     """
     # This method typically involves training the model for a few epochs with a range of learning rates and recording the loss at each step. The learning rate that gives the fastest decrease in loss is considered optimal or near-optimal for the training process.
-    :param config_obj:
-    :return:
     """
-    model = get_model(config_obj)
     # get loss
     loss_func = torch.nn.MSELoss(reduction="mean")
     # get callbacks
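
A minimal usage sketch of the flow this patch sets up: warm-starting fine-tuning from an earlier checkpoint by feeding its path to get_model(), then letting learner.lr_finder() pick the learning rate instead of a hard-coded value. The import locations, the load_data() helper, and the checkpoint filename below are illustrative assumptions, not part of the diff:

# Sketch only: module paths and the checkpoint name are assumptions.
from SeasonTST.utils import get_model, get_dls, get_learner        # assumed location of get_dls/get_learner
from SeasonTST.dataset import SeasonTST_Dataset                    # assumed module path
from SeasonTST.SeasonTST_finetune import load_config, load_data, finetune_func  # assumed importable

config_obj = load_config()
data, mask = load_data()
dls = get_dls(config_obj, SeasonTST_Dataset, data, mask)

# Resume from a previous fine-tuning checkpoint by passing its path explicitly,
# as the new comment in SeasonTST_finetune.py describes; keep the head so
# training continues rather than restarting from the pretrained backbone.
checkpoint_path = "saved_models/masked_patchtst/finetuned_d128.pth"  # hypothetical
model = get_model(
    config_obj, headtype="pretrain", weights_path=checkpoint_path, exclude_head=False
)

learner = get_learner(config_obj, dls, 0.0002, model)  # seed LR only
suggested_lr = learner.lr_finder()  # LR range test, as introduced in this patch
finetune_func(learner, checkpoint_path, config_obj, suggested_lr)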