diff --git a/experiment1/model.py b/experiment1/model.py index 79d3f90..b8617f0 100644 --- a/experiment1/model.py +++ b/experiment1/model.py @@ -35,8 +35,7 @@ def forward(self, x): def step(self, batch, batch_idx, *args, **kwargs) -> T.Tensor: # Unpacking - samples = batch["sample"] - targets = batch["target"] + samples, targets = batch # Forward targets_pred = self(samples)