diff --git a/frame_semantic_transformer/train.py b/frame_semantic_transformer/train.py index bc84625..f867f97 100644 --- a/frame_semantic_transformer/train.py +++ b/frame_semantic_transformer/train.py @@ -161,7 +161,7 @@ def training_epoch_end(self, training_step_outputs: list[Any]) -> None: self.model.save_pretrained(path) def validation_epoch_end(self, validation_step_outputs: list[Any]) -> None: - losses = [out["losses"].cpu() for out in validation_step_outputs] + losses = [out["loss"].cpu() for out in validation_step_outputs] self.average_validation_loss = np.round( torch.mean(torch.stack(losses)).item(), 4,