From 158bac97ad514bd1750307daa9dd51d43450bae7 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 5 Apr 2022 11:11:55 +0200 Subject: [PATCH] load best model from checkpoint --- tools/train.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tools/train.py b/tools/train.py index f1a91871ad..41698c6aca 100644 --- a/tools/train.py +++ b/tools/train.py @@ -26,7 +26,7 @@ from anomalib.config import get_configurable_parameters from anomalib.data import get_datamodule from anomalib.models import get_model -from anomalib.utils.callbacks import get_callbacks +from anomalib.utils.callbacks import LoadModelCallback, get_callbacks from anomalib.utils.loggers import get_logger @@ -59,6 +59,11 @@ def train(): trainer = Trainer(**config.trainer, logger=logger, callbacks=callbacks) trainer.fit(model=model, datamodule=datamodule) + + # load best model from checkpoint before evaluating + load_model_callback = LoadModelCallback(weights_path=trainer.checkpoint_callback.best_model_path) + trainer.callbacks.insert(0, load_model_callback) + trainer.test(model=model, datamodule=datamodule)