Test performance evaluation and early stopping (rusty1s#152)
early stopping, ckpt, test set

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
RexYing and rusty1s authored Jan 10, 2022
1 parent 4311233 commit cd917a7
Showing 4 changed files with 69 additions and 20 deletions.
1 change: 1 addition & 0 deletions benchmark/train/configs/financial.yaml
@@ -45,3 +45,4 @@ optim:
   optimizer: adam
   base_lr: 0.001
   max_epoch: 50
+  early_stopping: False
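
Opting in from a YAML config is then a one-line change. A minimal sketch of an enabled optim block, assuming the new keys defined in kumo/config/config.py below and leaving patience unset so the max_epoch-based fallback in main.py applies:

    optim:
      optimizer: adam
      base_lr: 0.001
      max_epoch: 50
      early_stopping: True
      min_delta: 0.001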
22 changes: 21 additions & 1 deletion benchmark/train/main.py
@@ -5,6 +5,7 @@
 import os
 import pytorch_lightning as pl
 from pytorch_lightning import loggers as pl_loggers
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from torch_geometric import seed_everything
 from torch_geometric.graphgym.cmd_args import parse_args
 from torch_geometric.graphgym.config import (load_cfg, get_fname)
@@ -57,11 +58,30 @@
 
 # Training
 gpus = 1 if cfg.device != 'cpu' and torch.cuda.is_available() else 0
+if cfg.optim.early_stopping:
+    patience = cfg.optim.patience
+    if patience is None:
+        patience = cfg.optim.max_epoch // 5
+    early_stopping = EarlyStopping(monitor="val_loss",
+                                   min_delta=cfg.optim.min_delta,
+                                   patience=patience, mode="min")
+    ckpt = ModelCheckpoint(save_top_k=1, monitor='val_loss',
+                           mode='min')
+    callbacks = [early_stopping, ckpt]
+else:
+    callbacks = None
+
 trainer = pl.Trainer(gpus=gpus, log_every_n_steps=1,
-                     max_epochs=cfg.optim.max_epoch, logger=tb_logger)
+                     max_epochs=cfg.optim.max_epoch, logger=tb_logger,
+                     callbacks=callbacks)
 # visualization
 visualize_scalar_distribution(
     dataset[cfg.dataset.target_table].y,
     os.path.join(tb_logger.log_dir, "target_distribution.png"),
 )
 trainer.fit(model, loaders[0], loaders[1])
+# test
+print('Validation on best epoch:')
+trainer.validate(ckpt_path='best', dataloaders=loaders[1])
+print('Test on best epoch:')
+trainer.test(ckpt_path='best', dataloaders=loaders[2])
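
For context on the Lightning pieces wired up above: EarlyStopping halts training once the monitored metric stops improving, and ModelCheckpoint with save_top_k=1 keeps only the single best checkpoint by val_loss, which is the checkpoint that ckpt_path='best' restores in trainer.validate and trainer.test. A standalone sketch with hypothetical literal values in place of the cfg fields (it assumes the LightningModule logs 'val_loss' during validation, which is not shown in this diff):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

    max_epoch = 50             # stands in for cfg.optim.max_epoch
    patience = max_epoch // 5  # the fallback used in main.py above

    early_stopping = EarlyStopping(monitor="val_loss", min_delta=0.001,
                                   patience=patience, mode="min")
    # save_top_k=1 retains only the best checkpoint by val_loss; passing
    # ckpt_path='best' to validate()/test() reloads exactly that file.
    ckpt = ModelCheckpoint(save_top_k=1, monitor="val_loss", mode="min")

    trainer = Trainer(gpus=0, max_epochs=max_epoch,
                      callbacks=[early_stopping, ckpt])
    # trainer.fit(model, train_loader, val_loader)
    # trainer.validate(ckpt_path='best', dataloaders=val_loader)
    # trainer.test(ckpt_path='best', dataloaders=test_loader)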
6 changes: 6 additions & 0 deletions kumo/config/config.py
@@ -53,6 +53,12 @@ def set_cfg(cfg):
     # Restriction: the split column has to be in the prediction target table.
     cfg.dataset.split_column = None
 
+    # early stopping configs
+    cfg.optim.early_stopping = False
+    cfg.optim.min_delta = 0.001
+    # if None, patience defaults to max_epoch // 5 (see benchmark/train/main.py)
+    cfg.optim.patience = None
+
     # Overwrite GraphGym scheduler
     # (might improve default optimizer after more training experiences
     # on databases)
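The two thresholds interact: with mode="min", a drop in val_loss only counts as an improvement when it exceeds min_delta, and patience consecutive non-improving validation epochs trigger the stop. A toy re-implementation of that bookkeeping, for illustration only (the real logic lives in pytorch_lightning.callbacks.EarlyStopping):

    def should_stop(val_losses, min_delta=0.001, patience=10):
        best, wait = float("inf"), 0
        for loss in val_losses:
            if best - loss > min_delta:  # improved by more than min_delta
                best, wait = loss, 0
            else:                        # stagnated, or improved too little
                wait += 1
                if wait >= patience:
                    return True
        return False

    # Each drop here is below min_delta, so all three count as stagnation:
    print(should_stop([0.50, 0.4995, 0.4993, 0.4992], patience=3))  # True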
60 changes: 41 additions & 19 deletions kumo/model/model_builder.py
@@ -148,6 +148,30 @@ def training_step(self, batch, batch_idx):
             self._eval_regression(y, pred_value)
         return loss
 
+    @torch.no_grad()
+    def _eval_metric(self, pred_value, label, loss, prefix="val"):
+        """ Validation and test evaluation metrics. """
+        if self.cfg.dataset.task_type == "classification":
+            pred_int = self._get_pred_int(pred_value)
+            acc = round(
+                accuracy_score(label.cpu(), pred_int.cpu()),
+                self.cfg.round,
+            )
+            auc = self._eval_auroc(label, pred_value)
+
+            self.log(f'{prefix}_acc', acc, on_step=False, on_epoch=True,
+                     prog_bar=True)
+            if auc is not None:
+                self.log(f'{prefix}_auc', auc, on_step=False, on_epoch=True,
+                         prog_bar=True)
+        elif self.cfg.dataset.task_type == "regression":
+            if self.current_epoch >= PLOT_EPOCH_THRESHOLD and \
+                    loss < self.val_best_loss:
+                self.plot_error_distribution(pred_value, label, metric="err",
+                                             prefix=prefix)
+                self.val_best_loss = loss
+            self._eval_regression(label, pred_value, prefix=prefix)
+
     def validation_step(self, batch, batch_idx):
         """"""
         feat_dict, discrete_feat_dict, edge_index_dict = (
@@ -162,26 +186,24 @@ def validation_step(self, batch, batch_idx):
             mask)
         loss, pred_value = compute_loss(pred, y)
 
-        if self.cfg.dataset.task_type == "classification":
-            pred_int = self._get_pred_int(pred_value)
-            acc = round(
-                accuracy_score(y.detach().cpu(), pred_int.cpu()),
-                self.cfg.round,
-            )
-            auc = self._eval_auroc(y, pred_value)
-
-            self.log("val_acc", acc, on_step=False, on_epoch=True,
-                     prog_bar=True)
-            if auc is not None:
-                self.log("val_auc", auc, on_step=False, on_epoch=True,
-                         prog_bar=True)
-        elif self.cfg.dataset.task_type == "regression":
-            if self.current_epoch >= PLOT_EPOCH_THRESHOLD and \
-                    loss < self.val_best_loss:
-                self.plot_error_distribution(pred_value, y, metric="err",
-                                             prefix="val")
-                self.val_best_loss = loss
-            self._eval_regression(y, pred_value, prefix="val")
+        self._eval_metric(pred_value, y, loss, prefix="val")
         return loss
 
+    def test_step(self, batch, batch_idx):
+        """"""
+        feat_dict, discrete_feat_dict, edge_index_dict = (
+            batch.feat_dict,
+            batch.discrete_feat_dict,
+            batch.edge_index_dict,
+        )
+        y = batch[self.target_table].y
+        mask = batch[self.target_table].test_mask
+
+        pred, y = self.model(feat_dict, edge_index_dict, discrete_feat_dict, y,
+                             mask)
+        loss, pred_value = compute_loss(pred, y)
+
+        self._eval_metric(pred_value, y, loss, prefix="test")
+        return loss
+
     def _configure_scheduler(self, optimizer, scheduler, min_lr=0.0001):
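The refactor above deduplicates validation and test evaluation into _eval_metric, parameterized by a logging prefix, so Lightning records metrics under distinct keys (val_acc/val_auc during validation, test_acc/test_auc during the test pass) and early stopping keeps monitoring only the validation side. A contrived sketch of the shared-helper pattern, with a fake log method standing in for LightningModule.log:

    # Not the real model class; illustrates prefix-based metric logging only.
    class Evaluator:
        def __init__(self):
            self.logged = {}

        def log(self, name, value):  # stand-in for LightningModule.log
            self.logged[name] = value

        def _eval_metric(self, acc, prefix="val"):
            self.log(f"{prefix}_acc", acc)

        def validation_step(self, acc):
            self._eval_metric(acc, prefix="val")

        def test_step(self, acc):
            self._eval_metric(acc, prefix="test")

    ev = Evaluator()
    ev.validation_step(0.91)
    ev.test_step(0.89)
    print(ev.logged)  # {'val_acc': 0.91, 'test_acc': 0.89}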
