Skip to content

Commit

Permalink
Model1 with code sharing
Browse files Browse the repository at this point in the history
  • Loading branch information
JVGD committed Oct 24, 2021
1 parent eac8272 commit 149b965
Showing 1 changed file with 10 additions and 17 deletions.
27 changes: 10 additions & 17 deletions experiment1/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from experiment1.modules import Adder, Substracter


class Modular1(LightningModule):
class Model1(LightningModule):
def __init__(self):
"""Modular AI approach 1"""
super().__init__()
Expand All @@ -28,7 +28,7 @@ def forward(self, x):
y = self.weights_adder * y0 + self.weights_substracter * y1
return y

def training_step(self, batch, batch_idx, *args, **kwargs) -> T.Tensor:
def step(self, batch, batch_idx, *args, **kwargs) -> T.Tensor:
# Unpacking
samples = batch["samples"]
targets = batch["targets"]
Expand All @@ -38,22 +38,15 @@ def training_step(self, batch, batch_idx, *args, **kwargs) -> T.Tensor:

# Loss
loss = self.criteria(targets, targets_pred)
return loss

# Logging
self.log("loss/train", loss, prog_bar=True, on_step=False, on_epoch=True)
def training_step(self, batch, batch_idx, *args, **kwargs) -> T.Tensor:
# Loss
loss = self.step(batch, batch_idx, *args, *kwargs)
self.log("loss/train", loss)
return loss

def validation_step(self, batch, batch_idx, *args, **kwargs) -> T.Tensor:
# Unpacking
samples = batch["samples"]
targets = batch["targets"]

# Forward
targets_pred = self(samples)

def validation_step(self, batch, batch_idx, *args, **kwargs) -> None:
# Loss
loss = self.criteria(targets, targets_pred)

# Logging
self.log("loss/valid", loss, prog_bar=True, on_step=False, on_epoch=True)
return loss
loss = self.step(batch, batch_idx, *args, *kwargs)
self.log("loss/valid", loss)

0 comments on commit 149b965

Please sign in to comment.