From eac827256003839240dd71214e2b3cd132cbac4e Mon Sep 17 00:00:00 2001 From: Javier Vargas Date: Sat, 23 Oct 2021 16:54:47 +0200 Subject: [PATCH] Added validation step --- experiment1/model.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/experiment1/model.py b/experiment1/model.py index 8e66786..2f03a8a 100644 --- a/experiment1/model.py +++ b/experiment1/model.py @@ -43,3 +43,17 @@ def training_step(self, batch, batch_idx, *args, **kwargs) -> T.Tensor: self.log("loss/train", loss, prog_bar=True, on_step=False, on_epoch=True) return loss + def validation_step(self, batch, batch_idx, *args, **kwargs) -> T.Tensor: + # Unpacking + samples = batch["samples"] + targets = batch["targets"] + + # Forward + targets_pred = self(samples) + + # Loss + loss = self.criteria(targets, targets_pred) + + # Logging + self.log("loss/valid", loss, prog_bar=True, on_step=False, on_epoch=True) + return loss \ No newline at end of file