diff --git a/classifier.py b/classifier.py
index d5ec405..649c4cb 100644
--- a/classifier.py
+++ b/classifier.py
@@ -278,8 +278,11 @@ def train(args):
     model = model.to(device)

     lr = args.lr
-    # optimizer = AdamW(model.parameters(), lr=lr)
-    optimizer = SophiaG(model.parameters(), lr=lr, eps=1e-12, rho=0.03, betas=(0.985, 0.99), weight_decay=2e-1)
+    if args.optimizer == "adamw":
+        optimizer = AdamW(model.parameters(), lr=lr)
+    elif args.optimizer == "sophiag":
+        optimizer = SophiaG(model.parameters(), lr=lr, eps=1e-12, rho=0.03, betas=(0.985, 0.99), weight_decay=2e-1)
+
     hess_interval = 10
     iter_num = 0
@@ -419,9 +422,11 @@ def get_args():
     parser.add_argument("--batch_size", help='sst: 64, cfimdb: 8 can fit a 12GB GPU', type=int, default=8)
     parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
+    parser.add_argument("--optimizer", type=str, default="adamw")

     args, _ = parser.parse_known_args()
+    # TODO: Possibly change defaults based on optimizer
     parser.add_argument("--lr", type=float, help="learning rate, default lr for 'pretrain': 1e-3, 'finetune': 1e-5",
                         default=1e-5 if args.option == 'finetune' else 1e-3)
diff --git a/multitask_classifier.py b/multitask_classifier.py
index 8760ec4..541b9f1 100644
--- a/multitask_classifier.py
+++ b/multitask_classifier.py
@@ -9,7 +9,7 @@
 from torch.utils.tensorboard import SummaryWriter

 from bert import BertModel
-from optimizer import AdamW
+from optimizer import AdamW, SophiaG
 from tqdm import tqdm

 from datasets import SentenceClassificationDataset, SentencePairDataset, \
@@ -177,7 +177,12 @@ def train_multitask(args):
     model = model.to(device)

     lr = args.lr
-    optimizer = AdamW(model.parameters(), lr=lr)
+
+    if args.optimizer == "adamw":
+        optimizer = AdamW(model.parameters(), lr=lr)
+    elif args.optimizer == "sophiag":
+        optimizer = SophiaG(model.parameters(), lr=lr, eps=1e-12, rho=0.03, betas=(0.985, 0.99), weight_decay=2e-1)
+
     best_dev_acc_para = 0
     best_dev_acc_sst = 0
     best_dev_acc_sts = 0
@@ -339,9 +344,15 @@ def get_args():
     # hyper parameters
     parser.add_argument("--batch_size", help='sst: 64, cfimdb: 8 can fit a 12GB GPU', type=int, default=8)
     parser.add_argument("--hidden_dropout_prob", type=float, default=0.3)
-    parser.add_argument("--lr", type=float, help="learning rate, default lr for 'pretrain': 1e-3, 'finetune': 1e-5",
-                        default=1e-5)
+    parser.add_argument("--optimizer", type=str, default="adamw")
+
+    args, _ = parser.parse_known_args()
+
+    # TODO: Possibly change defaults based on optimizer
+    parser.add_argument("--lr", type=float, help="learning rate, default lr for 'pretrain': 1e-3, 'finetune': 1e-5",
+                        default=1e-5 if args.option == 'finetune' else 1e-3)
+
     args = parser.parse_args()
     return args
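
For context, the patch routes optimizer construction in both classifier.py and multitask_classifier.py through the new --optimizer flag, choosing between AdamW and SophiaG. Below is a minimal sketch of that selection logic factored into a helper; build_optimizer is a hypothetical name and the final else guard is an illustrative addition, not part of the patch (as written, an unrecognized --optimizer value would leave optimizer unbound).

# Sketch only: mirrors the optimizer selection added in this diff, assuming the
# repo's optimizer module exports AdamW and SophiaG as imported in the patch.
from optimizer import AdamW, SophiaG


def build_optimizer(name, params, lr):
    # SophiaG hyperparameters are copied verbatim from the patched scripts.
    if name == "adamw":
        return AdamW(params, lr=lr)
    elif name == "sophiag":
        return SophiaG(params, lr=lr, eps=1e-12, rho=0.03,
                       betas=(0.985, 0.99), weight_decay=2e-1)
    else:
        # Illustrative guard; the patch itself has no fallback branch.
        raise ValueError(f"unknown optimizer {name!r}, expected 'adamw' or 'sophiag'")


# Usage inside train() / train_multitask(), mirroring the patched code:
# optimizer = build_optimizer(args.optimizer, model.parameters(), args.lr)

Assuming the scripts' existing flags, a run with the new option would look like: python classifier.py --option finetune --optimizer sophiag.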