From 149b965431ac4851448ad45c07ed258165eb2690 Mon Sep 17 00:00:00 2001 From: Javier Vargas Date: Sun, 24 Oct 2021 09:28:30 +0200 Subject: [PATCH] Model1 with code sharing --- experiment1/model.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/experiment1/model.py b/experiment1/model.py index 2f03a8a..98e798d 100644 --- a/experiment1/model.py +++ b/experiment1/model.py @@ -6,7 +6,7 @@ from experiment1.modules import Adder, Substracter -class Modular1(LightningModule): +class Model1(LightningModule): def __init__(self): """Modular AI approach 1""" super().__init__() @@ -28,7 +28,7 @@ def forward(self, x): y = self.weights_adder * y0 + self.weights_substracter * y1 return y - def training_step(self, batch, batch_idx, *args, **kwargs) -> T.Tensor: + def step(self, batch, batch_idx, *args, **kwargs) -> T.Tensor: # Unpacking samples = batch["samples"] targets = batch["targets"] @@ -38,22 +38,15 @@ def training_step(self, batch, batch_idx, *args, **kwargs) -> T.Tensor: # Loss loss = self.criteria(targets, targets_pred) + return loss - # Logging - self.log("loss/train", loss, prog_bar=True, on_step=False, on_epoch=True) + def training_step(self, batch, batch_idx, *args, **kwargs) -> T.Tensor: + # Loss + loss = self.step(batch, batch_idx, *args, *kwargs) + self.log("loss/train", loss) return loss - def validation_step(self, batch, batch_idx, *args, **kwargs) -> T.Tensor: - # Unpacking - samples = batch["samples"] - targets = batch["targets"] - - # Forward - targets_pred = self(samples) - + def validation_step(self, batch, batch_idx, *args, **kwargs) -> None: # Loss - loss = self.criteria(targets, targets_pred) - - # Logging - self.log("loss/valid", loss, prog_bar=True, on_step=False, on_epoch=True) - return loss \ No newline at end of file + loss = self.step(batch, batch_idx, *args, *kwargs) + self.log("loss/valid", loss)