From 2b0d658235f4d6a20a36309209e9637556cd83dd Mon Sep 17 00:00:00 2001 From: tsugumi-sys Date: Tue, 2 Jan 2024 15:47:13 +0900 Subject: [PATCH] Fix broken evaluator --- pipeline/evaluator.py | 11 ++++++----- pipeline/moving_mnist/convlstm.py | 2 +- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pipeline/evaluator.py b/pipeline/evaluator.py index 42ed7fe..7501b6c 100644 --- a/pipeline/evaluator.py +++ b/pipeline/evaluator.py @@ -9,6 +9,8 @@ from torch.utils.data import DataLoader from torchvision import transforms +from core.constants import DEVICE + class Evaluator: def __init__( @@ -25,11 +27,10 @@ def __init__( def run(self): with torch.no_grad(): - for batch_idx, (input_frames, label_frames) in enumerate( - self.test_dataloader - ): - pred_frames = self.model(input_frames) - self.visualize_predlabel_frames(batch_idx, label_frames, pred_frames) + for batch_idx, (input, label) in enumerate(self.test_dataloader): + input, label = input.to(DEVICE), label.to(DEVICE) + pred_frames = self.model(input) + self.visualize_predlabel_frames(batch_idx, label, pred_frames) self.visualize_attention_maps(batch_idx) def visualize_predlabel_frames( diff --git a/pipeline/moving_mnist/convlstm.py b/pipeline/moving_mnist/convlstm.py index 8cc7997..69d9ab6 100644 --- a/pipeline/moving_mnist/convlstm.py +++ b/pipeline/moving_mnist/convlstm.py @@ -83,7 +83,7 @@ def main(): ### print("Evaluating ...") evaluator = Evaluator( - model=None, + model=model, test_dataloader=data_loaders.test_dataloader, save_dir_path="./tmp/evaluate", )