diff --git a/benchmark/train/configs/financial.yaml b/benchmark/train/configs/financial.yaml
index f98e5675..2641f339 100644
--- a/benchmark/train/configs/financial.yaml
+++ b/benchmark/train/configs/financial.yaml
@@ -45,3 +45,4 @@ optim:
   optimizer: adam
   base_lr: 0.001
   max_epoch: 50
+  early_stopping: False
diff --git a/benchmark/train/main.py b/benchmark/train/main.py
index c757b35c..4da419b2 100644
--- a/benchmark/train/main.py
+++ b/benchmark/train/main.py
@@ -5,6 +5,7 @@
 import os
 import pytorch_lightning as pl
 from pytorch_lightning import loggers as pl_loggers
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from torch_geometric import seed_everything
 from torch_geometric.graphgym.cmd_args import parse_args
 from torch_geometric.graphgym.config import (load_cfg, get_fname)
@@ -57,11 +58,30 @@
     # Training
     gpus = 1 if cfg.device != 'cpu' and torch.cuda.is_available() else 0
+    if cfg.optim.early_stopping:
+        patience = cfg.optim.patience
+        if patience is None:
+            patience = cfg.optim.max_epoch // 5
+        early_stopping = EarlyStopping(monitor="val_loss",
+                                       min_delta=cfg.optim.min_delta,
+                                       patience=patience, mode="min")
+        ckpt = ModelCheckpoint(save_top_k=1, monitor='val_loss',
+                               mode='min')
+        callbacks = [early_stopping, ckpt]
+    else:
+        callbacks = None
+
     trainer = pl.Trainer(gpus=gpus, log_every_n_steps=1,
-                         max_epochs=cfg.optim.max_epoch, logger=tb_logger)
+                         max_epochs=cfg.optim.max_epoch, logger=tb_logger,
+                         callbacks=callbacks)

     # visualization
     visualize_scalar_distribution(
         dataset[cfg.dataset.target_table].y,
         os.path.join(tb_logger.log_dir, "target_distribution.png"),
     )
     trainer.fit(model, loaders[0], loaders[1])
+    # Evaluate the best checkpoint on the validation and test sets
+    print('Validation on best epoch:')
+    trainer.validate(ckpt_path='best', dataloaders=loaders[1])
+    print('Test on best epoch:')
+    trainer.test(ckpt_path='best', dataloaders=loaders[2])
diff --git a/kumo/config/config.py b/kumo/config/config.py
index 52df3957..990daa63 100644
--- a/kumo/config/config.py
+++ b/kumo/config/config.py
@@ -53,6 +53,12 @@ def set_cfg(cfg):
     # Restriction: the split column has to be in the prediction target table.
     cfg.dataset.split_column = None

+    # Early stopping configs
+    cfg.optim.early_stopping = False
+    cfg.optim.min_delta = 0.001
+    # If None, patience defaults to cfg.optim.max_epoch // 5
+    cfg.optim.patience = None
+
     # Overwrite GraphGym scheduler
     # (might improve default optimizer after more training experiences
     # on databases)
diff --git a/kumo/model/model_builder.py b/kumo/model/model_builder.py
index 240c2324..16849d19 100644
--- a/kumo/model/model_builder.py
+++ b/kumo/model/model_builder.py
@@ -148,6 +148,30 @@ def training_step(self, batch, batch_idx):
             self._eval_regression(y, pred_value)
         return loss

+    @torch.no_grad()
+    def _eval_metric(self, pred_value, label, loss, prefix="val"):
+        """Validation and test evaluation metrics."""
+        if self.cfg.dataset.task_type == "classification":
+            pred_int = self._get_pred_int(pred_value)
+            acc = round(
+                accuracy_score(label.cpu(), pred_int.cpu()),
+                self.cfg.round,
+            )
+            auc = self._eval_auroc(label, pred_value)
+
+            self.log(f'{prefix}_acc', acc, on_step=False, on_epoch=True,
+                     prog_bar=True)
+            if auc is not None:
+                self.log(f'{prefix}_auc', auc, on_step=False, on_epoch=True,
+                         prog_bar=True)
+        elif self.cfg.dataset.task_type == "regression":
+            if self.current_epoch >= PLOT_EPOCH_THRESHOLD and \
+                    loss < self.val_best_loss:
+                self.plot_error_distribution(pred_value, label, metric="err",
+                                             prefix=prefix)
+                self.val_best_loss = loss
+            self._eval_regression(label, pred_value, prefix=prefix)
+
     def validation_step(self, batch, batch_idx):
         """"""
         feat_dict, discrete_feat_dict, edge_index_dict = (
@@ -162,26 +186,24 @@ def validation_step(self, batch, batch_idx):
                                mask)
         loss, pred_value = compute_loss(pred, y)

-        if self.cfg.dataset.task_type == "classification":
-            pred_int = self._get_pred_int(pred_value)
-            acc = round(
-                accuracy_score(y.detach().cpu(), pred_int.cpu()),
-                self.cfg.round,
-            )
-            auc = self._eval_auroc(y, pred_value)
+        self._eval_metric(pred_value, y, loss, prefix="val")
+        return loss

-            self.log("val_acc", acc, on_step=False, on_epoch=True,
-                     prog_bar=True)
-            if auc is not None:
-                self.log("val_auc", auc, on_step=False, on_epoch=True,
-                         prog_bar=True)
-        elif self.cfg.dataset.task_type == "regression":
-            if self.current_epoch >= PLOT_EPOCH_THRESHOLD and \
-                    loss < self.val_best_loss:
-                self.plot_error_distribution(pred_value, y, metric="err",
-                                             prefix="val")
-                self.val_best_loss = loss
-            self._eval_regression(y, pred_value, prefix="val")
+    def test_step(self, batch, batch_idx):
+        """"""
+        feat_dict, discrete_feat_dict, edge_index_dict = (
+            batch.feat_dict,
+            batch.discrete_feat_dict,
+            batch.edge_index_dict,
+        )
+        y = batch[self.target_table].y
+        mask = batch[self.target_table].test_mask
+
+        pred, y = self.model(feat_dict, edge_index_dict, discrete_feat_dict, y,
+                             mask)
+        loss, pred_value = compute_loss(pred, y)
+
+        self._eval_metric(pred_value, y, loss, prefix="test")
         return loss

     def _configure_scheduler(self, optimizer, scheduler, min_lr=0.0001):